-
-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move nested static numba-jit functions #1438
base: main
Are you sure you want to change the base?
Conversation
aesara/link/numba/dispatch/basic.py
Outdated
fn = None | ||
# The type can also be RandomType with no ndims | ||
if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0: | ||
# TODO: Do we really need to compile a pass-through function like this? | ||
@numba_njit(inline="always") | ||
def deepcopyop(x): | ||
return x | ||
|
||
fn = deepcopyop_1 | ||
else: | ||
fn = deepcopyop_2 | ||
return fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These would probably be better as distinct return
statements.
aesara/link/numba/dispatch/basic.py
Outdated
return x | ||
|
||
|
||
@numba_njit(inline="always") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reminder: we should revisit our liberal use of inline="always"
; it's likely causing us to spend more time compiling things in Numba.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm guessing that this approach is bringing out new errors like this because we're causing Numba to dispatch on different types of the x
argument in deepcopyop_2
when we were previously creating distinct functions for each argument type. It's possible that we've always been doing something relatively inconsistent this whole time and we're only now seeing it, or it could be a Numba bug.
Yes, exactly. I tried to reproduce the bug locally but wasn't successful in reproducing it. Snippet
from numba import njit
import numpy as np
@njit
def f(x):
return x
@njit
def g(x):
return x.copy()
def y(o):
if o > 1:
fn = f
else:
fn = g
return fn
print(y(-2)(np.array([4, 5, 6], dtype=int)))
print(y(-2)(np.array([1.0, 2.0, 3.0], dtype=float)))
print(y(2)(np.array([-4, -5, -6], dtype=int)))
print(y(2)(np.array([-1.0, -2.0, -3.0], dtype=float))) This worked fine for me. |
I'm able to reproduce the error in CI locally from |
aesara/link/numba/dispatch/basic.py
Outdated
@numba_njit(inline="always") | ||
def deepcopyop_1(x): | ||
return x | ||
|
||
|
||
@numba_njit(inline="always") | ||
def deepcopyop_2(x): | ||
return x.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: If we remove the numba_njit
decorators from here and apply them in numba_funcify_DeepCopyOp
, the error should go away.
We should probably proceed in that fashion anyway, since some of the options specified in numba_njit
are user-configurable within the Python session, and I'm not entirely sure if such changes will be affected if we AOT-like compile them in the Aesara package like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Then we will need to figure out some other way to cache and if the user changes configurations at every run, I think the previous approach would do better.
ae4f989
to
53bb10c
Compare
|
||
def make_node_key(node): | ||
"""Create a cache key for `node`. | ||
TODO: Currently this works only with Apply Node |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
node
should only ever be an Apply
instance.
aesara/link/numba/dispatch/basic.py
Outdated
"""Persist a Numba JIT-able Python function. | ||
Parameters |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Persist a Numba JIT-able Python function. | |
Parameters | |
"""Persist a Numba JIT-able Python function. | |
Parameters |
We need to discuss more on this approach and see how the functions are pickled in Failure
% pytest tests/link/numba/test_basic.py
============================= test session starts ==============================
platform darwin -- Python 3.10.8, pytest-7.2.1, pluggy-1.0.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/thebigbool/repos/aesara, configfile: pyproject.toml
plugins: xdist-3.1.0, cov-4.0.0, benchmark-4.0.0
collected 73 items
tests/link/numba/test_basic.py ...................FFFFFFFFFFFFFFFxFFFatal Python error: Segmentation fault
Current thread 0x0000000202e1e2c0 (most recent call first):
Garbage-collecting
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/ast.py", line 50 in parse
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/source.py", line 185 in getstatementrange_ast
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 263 in getsource
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 722 in _getentrysource
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 814 in repr_traceback_entry
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 871 in repr_traceback
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 944 in repr_excinfo
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 669 in getrepr
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/nodes.py", line 484 in _repr_failure_py
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/python.py", line 1823 in repr_failure
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/reports.py", line 349 in from_item_and_call
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 366 in pytest_runtest_makereport
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 222 in call_and_report
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/bin/pytest", line 10 in <module>
Extension modules: yaml._yaml, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, scipy._lib._ccallback_c, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg._cythonized_array_utils, scipy.linalg._flinalg, scipy.linalg._solve_toeplitz, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_lapack, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse._sparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.special._ellip_harm_2, numpy.linalg.lapack_lite, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, beta_ufunc, scipy.stats._boost.beta_ufunc, binom_ufunc, scipy.stats._boost.binom_ufunc, nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, ncf_ufunc, scipy.stats._boost.ncf_ufunc, ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, nct_ufunc, scipy.stats._boost.nct_ufunc, skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._mvn, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, _cffi_backend (total: 122)
zsh: segmentation fault pytest tests/link/numba/test_basic.py |
Oh, yeah, that looks like a low-level issue. I'll try it out locally in the next couple of days. |
Hmm, this gets tricker when we start digging at a low level. IMHO, we can do the following:
What do you think @brandonwillard? |
These are all mostly distinct approaches, so they can be done separately. More specifically, the |
I just added an option to disable |
aesara/configdefaults.py
Outdated
@@ -378,6 +378,13 @@ def add_basic_configvars(): | |||
in_c_key=False, | |||
) | |||
|
|||
config.add( | |||
"DISABLE_NUMBA_CACHE", | |||
("Disable numba caching in the backend"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
("Disable numba caching in the backend"), | |
("Disable caching of the Aesara-generated Python IR used by the Numba backend"), |
aesara/configdefaults.py
Outdated
@@ -378,6 +378,13 @@ def add_basic_configvars(): | |||
in_c_key=False, | |||
) | |||
|
|||
config.add( | |||
"DISABLE_NUMBA_CACHE", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to change this to something like DISABLE_NUMBA_PYTHON_IR_CACHING
.
|
||
# TODO: Presently numba_py_fn is already jitted. | ||
# numba_fn = numba_njit(numba_py_fn) | ||
return cast(Callable, numba_py_fn) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: I think the tests are failing because we haven't finished refactoring the rest of the code so that it's aware of numba_funcify
now returning un-njit
ed Python functions. For example, test_config_options_cached
is expecting numba_mul_fn
to be a Numba CPUDispatcher
object and not a plain Python function object.
Related to #1419
Here are a few important guidelines and requirements to check before your PR can be merged:
pre-commit
is installed and set up.