Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify Type with Python's type system #1207

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
02f1435
Turn `HasDataType` and `HasShape` into `Protocol`s
markusschmaus Sep 23, 2022
0929c9d
Prepare switch from class to metaclass for `Type`
markusschmaus Sep 24, 2022
b1d4098
Convert `__init__`s to only accept keyword arguments
markusschmaus Sep 25, 2022
c99a120
Stop renaming variables in Scan's merge rewrite
brandonwillard Sep 28, 2022
600a118
Allow shared RandomState/Generator updates in JAX compiled functions
ricardoV94 Sep 26, 2022
3125091
Fix _debugprint's handling of empty profile data
brandonwillard Sep 30, 2022
bc2f292
Make `Type` a subclass of `type`
markusschmaus Oct 3, 2022
0813e34
Make compilation modes configurable in compare_numba_and_py
brandonwillard Sep 27, 2022
21c46ac
Enable Numba bounds checking during testing
brandonwillard Sep 21, 2022
47ecc5f
Use to_scalar in numba_funcify_ScalarFromTensor
brandonwillard Sep 27, 2022
ff1ad55
Allow FunctionGraph arguments in create_numba_signature
brandonwillard Sep 23, 2022
26bcc9b
Stop using auto_name for transpilation
brandonwillard Sep 27, 2022
8e67e04
Do not overwrite arguments in Numba's Scan implementation
brandonwillard Sep 21, 2022
e7b8e9b
Fix storage handling in numba_funcify_Scan
brandonwillard Sep 21, 2022
9d07ce9
Make fgraph_to_python process constant FunctionGraph outputs correctly
brandonwillard Oct 4, 2022
a25afa3
Add gufunc signature to `RandomVariable`\'s docstrings
rlouf Aug 31, 2022
6302849
Update pre-commit hooks
brandonwillard Oct 7, 2022
6198d7a
Fix make_numba_random_fn RandomStateType check
brandonwillard Oct 6, 2022
2799bad
Support updates and more input types in compare_numba_and_py
brandonwillard Oct 7, 2022
41eeeb6
Add support for shared inputs in numba_funcify_Scan
brandonwillard Oct 6, 2022
4ef874c
change `issubtype` to `isinstance` on metaclass
markusschmaus Oct 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 22.10.0
hooks:
- id: black
language_version: python3
Expand All @@ -47,7 +47,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.971
rev: v0.982
hooks:
- id: mypy
additional_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion aesara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_scalar_constant_value(v):
"""
# Is it necessary to test for presence of aesara.sparse at runtime?
sparse = globals().get("sparse")
if sparse and isinstance(v.type, sparse.SparseTensorType):
if sparse and isinstance(v.type, sparse.SparseTensorTypeMeta):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_scalar_constant_value(data)
Expand Down
2 changes: 1 addition & 1 deletion aesara/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def perform(self, node, inputs, output_storage):
output_storage[i][0] = inputs[i + 1]

def grad(self, inputs, output_gradients):
return [DisconnectedType()()] + output_gradients
return [DisconnectedType.subtype()()] + output_gradients

def infer_shape(self, fgraph, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
Expand Down
24 changes: 12 additions & 12 deletions aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aesara.compile.mode import optdb
from aesara.compile.sharedvalue import SharedVariable
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, grad
from aesara.gradient import DisconnectedTypeMeta, Rop, grad
from aesara.graph.basic import (
Apply,
Constant,
Expand All @@ -22,7 +22,7 @@
replace_nominals_with_dummies,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.null_type import NullTypeMeta
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.rewriting.basic import in2out, node_rewriter
from aesara.graph.utils import MissingInputError
Expand Down Expand Up @@ -210,7 +210,7 @@ def _filter_grad_var(grad, inp):
#
# For now, this converts NullType or DisconnectedType into zeros_like.
# other types are unmodified: overrider_var -> None
if isinstance(grad.type, (NullType, DisconnectedType)):
if isinstance(grad.type, (NullTypeMeta, DisconnectedTypeMeta)):
if hasattr(inp, "zeros_like"):
return inp.zeros_like(), grad
else:
Expand All @@ -221,9 +221,9 @@ def _filter_grad_var(grad, inp):
@staticmethod
def _filter_rop_var(inpJ, out):
# mostly similar to _filter_grad_var
if isinstance(inpJ.type, NullType):
if isinstance(inpJ.type, NullTypeMeta):
return out.zeros_like(), inpJ
if isinstance(inpJ.type, DisconnectedType):
if isinstance(inpJ.type, DisconnectedTypeMeta):
# since R_op does not have DisconnectedType yet, we will just
# make them zeros.
return out.zeros_like(), None
Expand Down Expand Up @@ -502,7 +502,7 @@ def lop_op(inps, grads):
all_grads_l = list(all_grads_l)
all_grads_ov_l = list(all_grads_ov_l)
elif isinstance(lop_op, Variable):
if isinstance(lop_op.type, (DisconnectedType, NullType)):
if isinstance(lop_op.type, (NullTypeMeta, DisconnectedTypeMeta)):
all_grads_l = [inp.zeros_like() for inp in local_inputs]
all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
else:
Expand All @@ -529,7 +529,7 @@ def lop_op(inps, grads):
all_grads_l.append(gnext)
all_grads_ov_l.append(gnext_ov)
elif isinstance(fn_gov, Variable):
if isinstance(fn_gov.type, (DisconnectedType, NullType)):
if isinstance(fn_gov.type, (NullTypeMeta, DisconnectedTypeMeta)):
all_grads_l.append(inp.zeros_like())
all_grads_ov_l.append(fn_gov.type())
else:
Expand Down Expand Up @@ -614,10 +614,10 @@ def _recompute_rop_op(self):
all_rops_l = list(all_rops_l)
all_rops_ov_l = list(all_rops_ov_l)
elif isinstance(rop_op, Variable):
if isinstance(rop_op.type, NullType):
if isinstance(rop_op.type, NullTypeMeta):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
elif isinstance(rop_op.type, DisconnectedType):
elif isinstance(rop_op.type, DisconnectedTypeMeta):
all_rops_l = [inp.zeros_like() for inp in local_inputs]
all_rops_ov_l = [None] * out_len
else:
Expand All @@ -644,10 +644,10 @@ def _recompute_rop_op(self):
all_rops_l.append(rnext)
all_rops_ov_l.append(rnext_ov)
elif isinstance(fn_rov, Variable):
if isinstance(fn_rov.type, NullType):
if isinstance(fn_rov.type, NullTypeMeta):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(fn_rov.type())
if isinstance(fn_rov.type, DisconnectedType):
if isinstance(fn_rov.type, DisconnectedTypeMeta):
all_rops_l.append(out.zeros_like())
all_rops_ov_l.append(None)
else:
Expand Down Expand Up @@ -857,7 +857,7 @@ def connection_pattern(self, node):
# cpmat_self &= out_is_disconnected
for i, t in enumerate(self._lop_op_stypes_l):
if t is not None:
if isinstance(t.type, DisconnectedType):
if isinstance(t.type, DisconnectedTypeMeta):
for o in range(out_len):
cpmat_self[i][o] = False
for o in range(out_len):
Expand Down
4 changes: 2 additions & 2 deletions aesara/compile/compiledir.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from aesara.configdefaults import config
from aesara.graph.op import Op
from aesara.link.c.type import CType
from aesara.link.c.type import CType, CTypeMeta
from aesara.utils import flatten


Expand Down Expand Up @@ -131,7 +131,7 @@ def print_compiledir_content():
zeros_op += 1
else:
types = list(
{x for x in flatten(keydata.keys) if isinstance(x, CType)}
{x for x in flatten(keydata.keys) if isinstance(x, CTypeMeta)}
)
compile_start = compile_end = float("nan")
for fn in os.listdir(os.path.join(compiledir, dir)):
Expand Down
13 changes: 7 additions & 6 deletions aesara/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from aesara.link.utils import map_storage, raise_with_op
from aesara.printing import _debugprint
from aesara.tensor import TensorType
from aesara.tensor.type import TensorTypeMeta
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function


Expand Down Expand Up @@ -792,7 +793,7 @@ def _get_preallocated_maps(
for r in considered_outputs:
# There is no risk to overwrite inputs, since r does not work
# inplace.
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
reuse_outputs[r][...] = np.asarray(def_val).astype(r.type.dtype)

if reuse_outputs:
Expand All @@ -805,7 +806,7 @@ def _get_preallocated_maps(
if "c_contiguous" in prealloc_modes or "ALL" in prealloc_modes:
c_cont_outputs = {}
for r in considered_outputs:
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
# Build a C-contiguous buffer
new_buf = r.type.value_zeros(r_vals[r].shape)
assert new_buf.flags["C_CONTIGUOUS"]
Expand All @@ -822,7 +823,7 @@ def _get_preallocated_maps(
if "f_contiguous" in prealloc_modes or "ALL" in prealloc_modes:
f_cont_outputs = {}
for r in considered_outputs:
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
new_buf = np.zeros(
shape=r_vals[r].shape, dtype=r_vals[r].dtype, order="F"
)
Expand Down Expand Up @@ -850,7 +851,7 @@ def _get_preallocated_maps(
max_ndim = 0
rev_out_broadcastable = []
for r in considered_outputs:
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
if max_ndim < r.ndim:
rev_out_broadcastable += [True] * (r.ndim - max_ndim)
max_ndim = r.ndim
Expand All @@ -865,7 +866,7 @@ def _get_preallocated_maps(
# Initial allocation
init_strided = {}
for r in considered_outputs:
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
# Create a buffer twice as large in every dimension,
# except if broadcastable, or for dimensions above
# config.DebugMode__check_preallocated_output_ndim
Expand Down Expand Up @@ -944,7 +945,7 @@ def _get_preallocated_maps(
name = f"wrong_size{tuple(shape_diff)}"

for r in considered_outputs:
if isinstance(r.type, TensorType):
if isinstance(r.type, TensorTypeMeta):
r_shape_diff = shape_diff[: r.ndim]
out_shape = [
max((s + sd), 0)
Expand Down
4 changes: 2 additions & 2 deletions aesara/compile/nanguardmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import aesara
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.tensor.random.type import RandomTypeMeta
from aesara.tensor.type import discrete_dtypes


Expand All @@ -30,13 +31,12 @@ def _is_numeric_value(arr, var):

"""
from aesara.link.c.type import _cdata_type
from aesara.tensor.random.type import RandomType

if isinstance(arr, _cdata_type):
return False
elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False
elif var and isinstance(var.type, RandomType):
elif var and isinstance(var.type, RandomTypeMeta):
return False
elif isinstance(arr, slice):
return False
Expand Down
10 changes: 5 additions & 5 deletions aesara/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.link.c.type import CType
from aesara.link.c.type import CTypeMeta


def register_view_op_c_code(type, code, version=()):
Expand Down Expand Up @@ -64,7 +64,7 @@ def c_code(self, node, nodename, inp, out, sub):
(oname,) = out
fail = sub["fail"]

itype = node.inputs[0].type.__class__
itype = node.inputs[0].type.base_type
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
return code % locals()
Expand Down Expand Up @@ -199,7 +199,7 @@ def c_code(self, node, name, inames, onames, sub):
(oname,) = onames
fail = sub["fail"]

itype = node.inputs[0].type.__class__
itype = node.inputs[0].type.base_type
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
return code % locals()
Expand Down Expand Up @@ -311,11 +311,11 @@ def numpy_dot(a, b):
"""
if not isinstance(itypes, (list, tuple)):
itypes = [itypes]
if not all(isinstance(t, CType) for t in itypes):
if not all(isinstance(t, CTypeMeta) for t in itypes):
raise TypeError("itypes has to be a list of Aesara types")
if not isinstance(otypes, (list, tuple)):
otypes = [otypes]
if not all(isinstance(t, CType) for t in otypes):
if not all(isinstance(t, CTypeMeta) for t in otypes):
raise TypeError("otypes has to be a list of Aesara types")

# make sure they are lists and not tuples
Expand Down
Loading