prange doesn't work in overloaded functions #116

Open
landmanbester opened this issue Sep 23, 2024 · 1 comment

@landmanbester
Collaborator

Trying to use prange inside an overload, e.g.

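from numba import njit, prange
from numba.extending import overload

# JIT_OPTIONS is a dict of numba jit options (its definition is not shown here)


@njit(**JIT_OPTIONS, parallel=True)
def update(x, xp, r, rp, p, Ap, alpha):
    return update_impl(x, xp, r, rp, p, Ap, alpha)


def update_impl(x, xp, r, rp, p, Ap, alpha):
    # pure-Python stub; the jitted implementation is supplied by @overload below
    raise NotImplementedError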


@overload(update_impl, jit_options=JIT_OPTIONS, parallel=True)
def nb_update_impl(x, xp, r, rp, p, Ap, alpha):
    if x.ndim==3:
        def impl(x, xp, r, rp, p, Ap, alpha):
            nband, nx, ny = x.shape
            for b in range(nband):
                for i in prange(nx):
                    for j in range(ny):
                        x[b, i, j] = xp[b, i, j] + alpha * p[b, i, j]
                        r[b, i, j] = rp[b, i, j] + alpha * Ap[b, i, j]
            return x, r
    elif x.ndim==2:
        def impl(x, xp, r, rp, p, Ap, alpha):
            nx, ny = x.shape
            for i in prange(nx):
                for j in range(ny):
                    x[i, j] = xp[i, j] + alpha * p[i, j]
                    r[i, j] = rp[i, j] + alpha * Ap[i, j]
            return x, r
    else:
        raise ValueError("update only implemented for 2D or 3D arrays")

    return impl

results in the following warning message during compilation:

/home/bester/.venv/pfb/lib/python3.10/site-packages/numba/core/typed_passes.py:336: NumbaPerformanceWarning:
The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see https://numba.readthedocs.io/en/stable/user/parallel.html#diagnostics for help.

File "../../software/pfb-imaging/pfb/opt/pcg.py", line 21:
@njit(**JIT_OPTIONS, parallel=True)
def update(x, xp, r, rp, p, Ap, alpha):
^

I've seen this before when doing nested function calls to prange (e.g. here). For that function I get the same warning, but I do actually see multiple threads spinning up, whereas the overloaded implementation doesn't seem to parallelize at all. I wonder whether this is a bug in numba or whether I'm trying something that is not supported.
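
One possible workaround is sketched below. It avoids @overload rather than explaining the warning. Each case becomes its own @njit(parallel=True) kernel, so prange sits directly in a function compiled with parallel=True, and the choice between them is made on ndim in plain Python. The names update_2d, update_3d and update_dispatch are hypothetical, JIT_OPTIONS is left out, and the try/except fallback is only there so the sketch still runs where numba is not installed. It has not been checked against the pfb code base.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # assumption: fall back to plain Python if numba is missing
    prange = range

    def njit(*args, **kwargs):
        return lambda func: func


@njit(parallel=True)
def update_2d(x, xp, r, rp, p, Ap, alpha):
    nx, ny = x.shape
    for i in prange(nx):  # prange is directly inside the parallel-compiled function
        for j in range(ny):
            x[i, j] = xp[i, j] + alpha * p[i, j]
            r[i, j] = rp[i, j] + alpha * Ap[i, j]
    return x, r


@njit(parallel=True)
def update_3d(x, xp, r, rp, p, Ap, alpha):
    nband, nx, ny = x.shape
    for b in range(nband):
        for i in prange(nx):
            for j in range(ny):
                x[b, i, j] = xp[b, i, j] + alpha * p[b, i, j]
                r[b, i, j] = rp[b, i, j] + alpha * Ap[b, i, j]
    return x, r


def update_dispatch(x, xp, r, rp, p, Ap, alpha):
    # hypothetical helper: select the kernel in Python instead of inside @overload
    if x.ndim == 2:
        return update_2d(x, xp, r, rp, p, Ap, alpha)
    if x.ndim == 3:
        return update_3d(x, xp, r, rp, p, Ap, alpha)
    raise ValueError("update only implemented for 2D or 3D arrays")
```

Because update_dispatch is ordinary Python, it cannot be called from inside another jitted function, which is the case @overload is meant to cover, so this only helps when update is called from the interpreter.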

@landmanbester
Collaborator Author

Looking at the parallel diagnostics gives:

In [8]: update.parallel_diagnostics()

================================================================================
 Parallel Accelerator Optimizing:  Function update, /home/bester/software/pfb-imaging/pfb/opt/pcg.py (20)
================================================================================
No source available
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
