Skip to content

Commit

Permalink
CI: add test for and fixes #920
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 17, 2023
1 parent 5312b30 commit d1d0369
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
8 changes: 1 addition & 7 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import singledispatch

import numpy as np
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple, Add
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol, Tuple

from devito.finite_differences import Derivative
from devito.finite_differences.differentiable import IndexDerivative
Expand Down Expand Up @@ -269,12 +269,6 @@ def sympy_dtype(expr, default):
returns the default if non is found
"""
args = expr.args
# We can only infer the dtype for addition/multiplication or Symbols
# For other case the epxression function may modify the infered dtype
if not (isinstance(expr.func, Add) or isinstance(expr.func, Mul)) or \
not expr.is_Symbol:
return default

# Symbol/... without argument, check its dtype
if len(args) == 0:
try:
Expand Down
5 changes: 4 additions & 1 deletion devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,10 @@ def _dist_data_gather(self, data):
return

# Compute dist map only once
data = self._C_as_ndarray(data)
try:
data = self._C_as_ndarray(data)
except AttributeError:
pass
dmap = self._dist_datamap
mask = self._dist_scatter_mask(dmap=dmap)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,6 +2604,21 @@ def test_premature_evalderiv_lowering(self):
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 1
assert op._profiler._sections['section0'].sops == 16

def test_dtype_aliases(self):
a = np.arange(64).reshape((8, 8))
grid = Grid(shape=a.shape, extent=(8, 8))

so = 2
f = Function(name='f', grid=grid, space_order=so, dtype=np.int32)
f.data[:] = a

fo = Function(name='fo', grid=grid, space_order=so, dtype=np.int32)
op = Operator(Eq(fo, f.dx))
op.apply()

assert FindNodes(Expression).visit(op)[0].dtype == np.float32
assert np.all(fo.data[:-1, :-1] == 6)


class TestIsoAcoustic(object):

Expand Down
2 changes: 1 addition & 1 deletion tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_tile_insteadof_collapse(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 4
assert len(trees) == 6

assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
Expand Down

0 comments on commit d1d0369

Please sign in to comment.