From edb83f105d6f5b07cdecfd2624b82b32477b29db Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 6 Sep 2024 14:02:08 +0000 Subject: [PATCH] compiler: Patch mapify-reduce for SparseTimeFunctions --- devito/ir/clusters/algorithms.py | 20 +++++++++++++++----- devito/ir/equations/equation.py | 22 ++++++++++++++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index e4e3c9deac..b463593438 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -9,7 +9,7 @@ from devito.finite_differences.elementary import Max, Min from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange, pull_dims, null_ispace) -from devito.ir.equations import OpMin, OpMax +from devito.ir.equations import OpMin, OpMax, identity_mapper from devito.ir.clusters.analysis import analyze from devito.ir.clusters.cluster import Cluster, ClusterGroup from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass @@ -580,8 +580,8 @@ def normalize_reductions_minmax(cluster): elif e.operation is OpMax: if not f.is_Input: expr = Eq(lhs, limits_mapper[lhs.dtype].min) - ispce = cluster.ispace.project(lambda i: i not in dims) - init.append(cluster.rebuild(exprs=expr, ispace=ispce)) + ispace = cluster.ispace.project(lambda i: i not in dims) + init.append(cluster.rebuild(exprs=expr, ispace=ispace)) processed.append(e.func(lhs, Max(lhs, rhs))) @@ -663,8 +663,18 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform): a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims, grid=grid) - processed.extend([Eq(a.indexify(), rhs), - e.func(lhs, a.indexify())]) + # Populate the Array (the "map" part) + processed.append(e.func(a.indexify(), rhs, operation=None)) + + # Set all untouched entried to the identity value if necessary + if e.conditionals: + nc = {d: sympy.Not(v) for d, v in e.conditionals.items()} + v = identity_mapper[e.lhs.dtype][e.operation] + processed.append( + e.func(a.indexify(), v, operation=None, conditionals=nc) + ) + + processed.append(e.func(lhs, a.indexify())) for d in sequentialize: properties = properties.sequentialize(d) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index b5f8b6979f..ada1c23f22 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -1,16 +1,18 @@ from functools import cached_property +import numpy as np import sympy from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.finite_differences.differentiable import diff2sympy from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_io, detect_accesses) -from devito.symbolics import IntDiv, uxreplace +from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min -__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax'] +__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax', + 'identity_mapper'] class IREq(sympy.Eq, Pickable): @@ -119,6 +121,22 @@ def detect(cls, expr): OpMin = Operation('min') +identity_mapper = { + np.int32: {OpInc: sympy.S.Zero, + OpMax: limits_mapper[np.int32].min, + OpMin: limits_mapper[np.int32].max}, + np.int64: {OpInc: sympy.S.Zero, + OpMax: limits_mapper[np.int64].min, + OpMin: limits_mapper[np.int64].max}, + np.float32: {OpInc: sympy.S.Zero, + OpMax: limits_mapper[np.float32].min, + OpMin: limits_mapper[np.float32].max}, + np.float64: {OpInc: sympy.S.Zero, + OpMax: limits_mapper[np.float64].min, + OpMin: limits_mapper[np.float64].max}, +} + + class LoweredEq(IREq): """