Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hash-cons Applys and Constants #1165

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 1 addition & 3 deletions aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,7 @@ class OpFromGraph(Op, HasInnerGraph):

.. TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try
is_same_graph_with_merge(op1.local_outputs, op2,
local_outputs)
- __hash__, __eq__ otherwise won't merge
- c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them
Expand Down
117 changes: 57 additions & 60 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from itertools import chain
from itertools import product as itertools_product
from logging import Logger
from typing import Optional
from typing import TYPE_CHECKING, Optional, Union
from warnings import warn

import numpy as np
from typing_extensions import Literal

import aesara
from aesara.compile.function.types import (
Expand All @@ -42,7 +43,9 @@
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function


__docformat__ = "restructuredtext en"
if TYPE_CHECKING:
from aesara.graph.basic import Apply

_logger: Logger = logging.getLogger("aesara.compile.debugmode")
_logger.addFilter(NoDuplicateOptWarningFilter())

Expand Down Expand Up @@ -1109,43 +1112,32 @@ class _FunctionGraphEvent:

"""

kind = ""
"""
One of 'import', 'change', 'prune'.

"""

node = None
"""
Either 'output' or an Apply instance.

"""

op = None
"""Either 'output' or an Op instance"""
kind: Literal["import", "change", "prune"]
old_node: Optional[Union[Literal["output"], "Apply"]]
new_node: Optional[Union[Literal["output"], "Apply"]]
op: Optional[Union[Literal["output"], Op]]
idx: Optional[int]
reason: Optional[str]

idx = None
"""
Change events involve an position index of the input variable.

"""

reason = None
"""
Change events sometimes have a reason.

"""

def __init__(self, kind, node, idx=None, reason=None):
def __init__(
self,
kind: Literal["import", "change", "prune"],
old_node: Union[Literal["output"], "Apply"],
new_node: Union[Literal["output"], "Apply"] = None,
idx: Optional[int] = None,
reason: Optional[str] = None,
):
self.kind = kind
if node == "output":
self.node = "output"
if old_node == "output":
self.old_node = "output"
self.new_node = "output"
self.op = "output"
else:
self.node = node
self.op = node.op
self.old_node = old_node
self.new_node = new_node
self.op = old_node.op
self.idx = idx
self.reason = str(reason)
self.reason = str(reason) if reason else None

def __str__(self):
if self.kind == "change":
Expand Down Expand Up @@ -1219,21 +1211,21 @@ def on_attach(self, fgraph):
self.replaced_by = {}
self.event_list = []
for node in fgraph.toposort():
self.on_import(fgraph, node, "on_attach")
self.on_import(fgraph, node, reason="on_attach")

def on_detach(self, fgraph):
assert fgraph is self.fgraph
self.fgraph = None

def on_prune(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason))
assert node in self.active_nodes
assert node not in self.inactive_nodes
self.active_nodes.remove(node)
self.inactive_nodes.add(node)

def on_import(self, fgraph, node, reason):
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
self.event_list.append(_FunctionGraphEvent("import", node, reason=reason))

assert node not in self.active_nodes
self.active_nodes.add(node)
Expand All @@ -1253,31 +1245,36 @@ def on_import(self, fgraph, node, reason):
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
def on_change_input(
self, fgraph, old_node, new_node, i, old_var, new_var, reason=None
):
reason = str(reason)
self.event_list.append(
_FunctionGraphEvent("change", node, reason=reason, idx=i)
_FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason)
)

self.reasons.setdefault(new_r, [])
self.replaced_by.setdefault(new_r, [])
self.on_import(fgraph, new_node, reason=reason)
self.on_prune(fgraph, old_node, reason=reason)

self.reasons.setdefault(new_var, [])
self.replaced_by.setdefault(new_var, [])

append_reason = True
for tup in self.reasons[new_r]:
if tup[0] == reason and tup[1] is r:
for tup in self.reasons[new_var]:
if tup[0] == reason and tup[1] is old_var:
append_reason = False

if append_reason:
# N.B. compute the debugprint now, because future
# optimizations will change the graph
done = dict()
used_ids = dict()
self.reasons[new_r].append(
self.reasons[new_var].append(
(
reason,
r,
old_var,
_debugprint(
r,
old_var,
prefix=" ",
depth=6,
file=StringIO(),
Expand All @@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
used_ids=used_ids,
).getvalue(),
_debugprint(
new_r,
new_var,
prefix=" ",
depth=6,
file=StringIO(),
Expand All @@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
).getvalue(),
)
)
self.replaced_by[r].append((reason, new_r))
self.replaced_by[old_var].append((reason, new_var))

if r in self.equiv:
r_set = self.equiv[r]
if old_var in self.equiv:
r_set = self.equiv[old_var]
else:
r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r)
r_set = self.equiv.setdefault(old_var, {old_var})
self.all_variables_ever.append(old_var)

if new_r in self.equiv:
new_r_set = self.equiv[new_r]
if new_var in self.equiv:
new_r_set = self.equiv[new_var]
else:
new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r)
new_r_set = self.equiv.setdefault(new_var, {new_var})
self.all_variables_ever.append(new_var)

assert new_r in new_r_set
assert r in r_set
assert new_var in new_r_set
assert old_var in r_set

# update one equivalence set to contain the other
# transfer all the elements of the old one to the new one
Expand All @@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
self.equiv[like_new_r] = r_set
assert like_new_r in r_set

assert self.equiv[r] is r_set
assert self.equiv[new_r] is r_set
assert self.equiv[old_var] is r_set
assert self.equiv[new_var] is r_set

def printstuff(self):
for key in self.equiv:
Expand Down
Loading