Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move nested static numba-jit functions #1438

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

Smit-create
Copy link
Member

Related to #1419

Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/BR.
  • There are tests covering the changes introduced in the PR.

@Smit-create Smit-create added enhancement New feature or request Numba Involves Numba transpilation labels Feb 17, 2023
Comment on lines 654 to 661
fn = None
# The type can also be RandomType with no ndims
if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba_njit(inline="always")
def deepcopyop(x):
return x

fn = deepcopyop_1
else:
fn = deepcopyop_2
return fn
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These would probably be better as distinct return statements.

return x


@numba_njit(inline="always")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder: we should revisit our liberal use of inline="always"; it's likely causing us to spend more time compiling things in Numba.

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing that this approach is bringing out new errors like this because we're causing Numba to dispatch on different types of the x argument in deepcopyop_2 when we were previously creating distinct functions for each argument type. It's possible that we've always been doing something relatively inconsistent this whole time and we're only now seeing it, or it could be a Numba bug.

@Smit-create
Copy link
Member Author

to dispatch on different types of the x argument in deepcopyop_2 when we were previously creating distinct functions for each argument type.

Yes, exactly. numba's jit should be able to handle this by itself each time when it's called. I tried some workarounds but numba fails with the same error as in the CI. Probably a bug on numba side?

I tried to reproduce the bug locally but wasn't successful in reproducing it.

Snippet

from numba import njit
import numpy as np

@njit
def f(x):
    return x

@njit
def g(x):
    return x.copy()

def y(o):
    if o > 1:
        fn = f
    else:
        fn = g
    return fn

print(y(-2)(np.array([4, 5, 6], dtype=int)))
print(y(-2)(np.array([1.0, 2.0, 3.0], dtype=float)))
print(y(2)(np.array([-4, -5, -6], dtype=int)))
print(y(2)(np.array([-1.0, -2.0, -3.0], dtype=float)))

This worked fine for me.

@brandonwillard
Copy link
Member

brandonwillard commented Feb 21, 2023

I tried to reproduce the bug locally but wasn't successful in reproducing it.

I'm able to reproduce the error in CI locally from test_Subtensor on this branch. Try clearing the local __pycache__ directories first.

Comment on lines 639 to 646
@numba_njit(inline="always")
def deepcopyop_1(x):
return x


@numba_njit(inline="always")
def deepcopyop_2(x):
return x.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: If we remove the numba_njit decorators from here and apply them in numba_funcify_DeepCopyOp, the error should go away.

We should probably proceed in that fashion anyway, since some of the options specified in numba_njit are user-configurable within the Python session, and I'm not entirely sure if such changes will be affected if we AOT-like compile them in the Aesara package like this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then we will need to figure out some other way to cache and if the user changes configurations at every run, I think the previous approach would do better.

@Smit-create Smit-create force-pushed the i-1419-1 branch 2 times, most recently from ae4f989 to 53bb10c Compare March 2, 2023 03:19

def make_node_key(node):
"""Create a cache key for `node`.
TODO: Currently this works only with Apply Node
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

node should only ever be an Apply instance.

Comment on lines 409 to 413
"""Persist a Numba JIT-able Python function.
Parameters
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Persist a Numba JIT-able Python function.
Parameters
"""Persist a Numba JIT-able Python function.
Parameters

@Smit-create
Copy link
Member Author

We need to discuss more on this approach and see how the functions are pickled in dill. The test failure on CI is very different from what it is on macOS which is hard to debug. See the failure on Mac M1:

Failure

% pytest tests/link/numba/test_basic.py
============================= test session starts ==============================
platform darwin -- Python 3.10.8, pytest-7.2.1, pluggy-1.0.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/thebigbool/repos/aesara, configfile: pyproject.toml
plugins: xdist-3.1.0, cov-4.0.0, benchmark-4.0.0
collected 73 items                                                             

tests/link/numba/test_basic.py ...................FFFFFFFFFFFFFFFxFFFatal Python error: Segmentation fault

Current thread 0x0000000202e1e2c0 (most recent call first):
  Garbage-collecting
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/ast.py", line 50 in parse
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/source.py", line 185 in getstatementrange_ast
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 263 in getsource
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 722 in _getentrysource
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 814 in repr_traceback_entry
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 871 in repr_traceback
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 944 in repr_excinfo
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/_code/code.py", line 669 in getrepr
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/nodes.py", line 484 in _repr_failure_py
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/python.py", line 1823 in repr_failure
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/reports.py", line 349 in from_item_and_call
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 366 in pytest_runtest_makereport
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 222 in call_and_report
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/Users/thebigbool/opt/anaconda3/envs/aesara-dev-1/bin/pytest", line 10 in <module>

Extension modules: yaml._yaml, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, scipy._lib._ccallback_c, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg._cythonized_array_utils, scipy.linalg._flinalg, scipy.linalg._solve_toeplitz, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_lapack, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse._sparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.special._ellip_harm_2, numpy.linalg.lapack_lite, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, beta_ufunc, scipy.stats._boost.beta_ufunc, binom_ufunc, scipy.stats._boost.binom_ufunc, nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, ncf_ufunc, scipy.stats._boost.ncf_ufunc, ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, nct_ufunc, scipy.stats._boost.nct_ufunc, skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._mvn, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, _cffi_backend (total: 122)
zsh: segmentation fault  pytest tests/link/numba/test_basic.py

@brandonwillard
Copy link
Member

We need to discuss more on this approach and see how the functions are pickled in dill. The test failure on CI is very different from what it is on macOS which is hard to debug. See the failure on Mac M1:

Oh, yeah, that looks like a low-level issue. I'll try it out locally in the next couple of days.

@Smit-create
Copy link
Member Author

Hmm, this gets tricker when we start digging at a low level. IMHO, we can do the following:

  1. Use the previous approach of moving static functions out and for some node-dependent functions, we can try to look into partial functions and their jitting with numba.
  2. This should work fine and perhaps we can shift our focus to fixing this: Investigate use of numba.config.DISABLE_JIT for coverage approximations #1449
  3. We can then try to address the following issue:

We should probably proceed in that fashion anyway, since some of the options specified in numba_njit are user-configurable within the Python session, and I'm not entirely sure if such changes will be affected if we AOT-like compile them in the Aesara package like this.

What do you think @brandonwillard?

@brandonwillard
Copy link
Member

Hmm, this gets tricker when we start digging at a low level. IMHO, we can do the following:

  1. Use the previous approach of moving static functions out and for some node-dependent functions, we can try to look into partial functions and their jitting with numba.
  2. This should work fine and perhaps we can shift our focus to fixing this: Investigate use of numba.config.DISABLE_JIT for coverage approximations #1449
  3. We can then try to address the following issue:

We should probably proceed in that fashion anyway, since some of the options specified in numba_njit are user-configurable within the Python session, and I'm not entirely sure if such changes will be affected if we AOT-like compile them in the Aesara package like this.

What do you think @brandonwillard?

These are all mostly distinct approaches, so they can be done separately.

More specifically, the numba.config.DISABLE_JIT idea is only a way to improve/simplify our workaround for estimating coverage; it shouldn't be blocking anything related to the code generation and caching, unless I've missed something.

@Smit-create
Copy link
Member Author

I just added an option to disable numba cache because while testing, if we continue using the cache which contains the jitted functions, the tests will fail because of type mismatch in numba jitted functions, and also won't provide coverage.
The CI failures now are a couple of test failures (Assertion Errors), probably because of some misinformation while generating the hash-key.

@@ -378,6 +378,13 @@ def add_basic_configvars():
in_c_key=False,
)

config.add(
"DISABLE_NUMBA_CACHE",
("Disable numba caching in the backend"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
("Disable numba caching in the backend"),
("Disable caching of the Aesara-generated Python IR used by the Numba backend"),

@@ -378,6 +378,13 @@ def add_basic_configvars():
in_c_key=False,
)

config.add(
"DISABLE_NUMBA_CACHE",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to change this to something like DISABLE_NUMBA_PYTHON_IR_CACHING.


# TODO: Presently numba_py_fn is already jitted.
# numba_fn = numba_njit(numba_py_fn)
return cast(Callable, numba_py_fn)
Copy link
Member

@brandonwillard brandonwillard Mar 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I think the tests are failing because we haven't finished refactoring the rest of the code so that it's aware of numba_funcify now returning un-njited Python functions. For example, test_config_options_cached is expecting numba_mul_fn to be a Numba CPUDispatcher object and not a plain Python function object.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request important Numba Involves Numba transpilation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants