Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove add_mul_fusion rewrite #1243

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aesara/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,16 +896,16 @@ def print_profile(cls, stream, prof, level=0):
print(blanc, " time_toposort", prof[7], file=stream)


fuse_seqopt = SequenceDB()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was potentially buggy, as fuse_seqopt was always being imported by the other module, even though its creation was conditional on a config flag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, should we get rid of it if we have only a single rewrite now?

if config.tensor__local_elemwise_fusion:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't thought enough about this just yet, but it really looks like we need to consider removing this option entirely, because, unless we redesign the whole approach to fusion/FusionOptimizer, I don't know whether this idea of applying a graph (i.e. "global") rewriter on each node (i.e. "locally") makes any sense.

I could see a place for a distinct node-based fusion rewrite, but I really don't think it would/should operate anything like a graph-based rewrite (e.g. it seems like there would be way too much unnecessary graph walking).

In general, the idea of a node-based fusion rewrite sounds appealing, because it could be a lot simpler than the/a graph-based one, but I don't know which types of graphs the former would miss that the latter wouldn't.

# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
"fast_run",
"fusion",
position=1,
)
# Position before AddDestroyHandler(49.5)
compile.optdb.register( # type: ignore
"elemwise_fusion",
fuse_seqopt,
Expand Down
61 changes: 0 additions & 61 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
register_uncanonicalize,
register_useless,
)
from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
Expand Down Expand Up @@ -2843,66 +2842,6 @@ def check_input(inputs):
return [ret]


def local_add_mul_fusion(fgraph, node):
"""Fuse consecutive add or mul in one such node with more inputs.
It is better to fuse add/mul that way then in a Composite node as
this make the inner graph of the Composite smaller. This allow to
put more computation in a Composite before hitting the max
recursion limit when pickling Composite.
"""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, (aes.Add, aes.Mul)
):
return False

s_op = node.op.scalar_op.__class__
new_inp = []
fused = False
nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
max_inputs = node.op.max_inputs(node)
for inp in node.inputs:
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, s_op)
and
# Do not duplicate the operation.
len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
):
new_inp.extend(inp.owner.inputs)
fused = True
else:
new_inp.append(inp)

# We can not compare the number of inputs as Mul and Add could have
# 0 or 1 inputs in some corner cases.
if fused:
output = node.op(*new_inp)
copy_stack_trace(node.outputs[0], output)

# Do the recursion here to help lower the number of
# FusionOptimizer iteration.
if output.owner:
output2 = local_add_mul_fusion(fgraph, output.owner)
if output2:
return output2
return [output]


fuse_seqopt.register(
"local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion),
"fast_run",
"fusion",
position=0,
)


def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
Expand Down
44 changes: 29 additions & 15 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,23 +998,35 @@ def test_big_fusion(self):
for node in dlogp.maker.fgraph.toposort()
)

def test_add_mul_fusion_inplace(self):

rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
)

mode = Mode(self.mode.linker, rewrites)
def test_add_mul_fusion_precedence(self):
"""Test that additions and multiplications are "fused together" before
a `Composite` `Op` is introduced. This fusion is done by canonicalization
"""
x, y, z = vectors("x", "y", "z")
out = log((x + y + z) / (x * y * z))
f = aesara.function([x, y, z], out, mode=self.mode)
# There should be a single Composite Op
nodes = f.maker.fgraph.apply_nodes
assert len(nodes) == 1
(node,) = nodes
assert isinstance(node.op, Elemwise)
scalar_op = node.op.scalar_op
assert isinstance(scalar_op, Composite)
assert [node.op for node in scalar_op.fgraph.toposort()] == [
# There should be a single mul
aes.mul,
# There should be a single add
aes.add,
aes.true_div,
aes.log,
]

def test_add_mul_fusion_inplace(self):
# Note: This has nothing to do with the FusionOptimizer, as the "fusion"
# is done by canonicalize
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=mode)
f = function([x, y, z], out, mode=self.mode)
topo = [n for n in f.maker.fgraph.toposort()]
assert len(topo) == 2
assert topo[-1].op.inplace_pattern
Expand All @@ -1026,7 +1038,9 @@ def test_add_mul_fusion_inplace(self):

# TODO: Do we really need to do this?
_ = f(
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
np.random.random((5, 5)),
np.random.random((5, 5)),
np.random.random((5, 5)),
)

@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
Expand Down
Loading