FIX: Update Numba Lecture to Address Deprecation of @jit #296

Merged
merged 7 commits into from
Dec 14, 2023
Changes from 4 commits
180 changes: 121 additions & 59 deletions lectures/numba.md
@@ -3,8 +3,10 @@ jupytext:
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.4
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
@@ -26,10 +28,9 @@ kernelspec:

In addition to what's in Anaconda, this lecture will need the following libraries:

```{code-cell} ipython
---
tags: [hide-output]
---
```{code-cell} ipython3
:tags: [hide-output]

!pip install quantecon
```

@@ -38,7 +39,7 @@ versions are a {doc}`common source of errors <troubleshooting>`.

Let's start with some imports:

```{code-cell} ipython
```{code-cell} ipython3
%matplotlib inline
import numpy as np
import quantecon as qe
@@ -98,13 +99,13 @@ $$

In what follows we set

```{code-cell} python3
```{code-cell} ipython3
α = 4.0
```

Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis

```{code-cell} python3
```{code-cell} ipython3
def qm(x0, n):
    x = np.empty(n+1)
    x[0] = x0
@@ -122,10 +123,10 @@ plt.show()

To speed the function `qm` up using Numba, our first step is

```{code-cell} python3
from numba import jit
```{code-cell} ipython3
from numba import njit

qm_numba = jit(qm)
qm_numba = njit(qm)
```

The function `qm_numba` is a version of `qm` that is "targeted" for
@@ -135,7 +136,7 @@ We will explain what this means momentarily.

Let's time and compare identical function calls across these two versions, starting with the original function `qm`:

```{code-cell} python3
```{code-cell} ipython3
n = 10_000_000

qe.tic()
@@ -145,7 +146,7 @@ time1 = qe.toc()

Now let's try `qm_numba`:

```{code-cell} python3
```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time2 = qe.toc()
@@ -156,13 +157,14 @@ This is already a massive speed gain.
In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:

(qm_numba_result)=
```{code-cell} python3

```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time3 = qe.toc()
```

```{code-cell} python3
```{code-cell} ipython3
time1 / time3 # Calculate speed gain
```

@@ -194,12 +196,12 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2

The compiled code is then cached and recycled as required.

## Decorators and "nopython" Mode
## Decorator Notation

In the code above we created a JIT compiled version of `qm` via the call

```{code-cell} python3
qm_numba = jit(qm)
```{code-cell} ipython3
qm_numba = njit(qm)
```

In practice this would typically be done using an alternative *decorator* syntax.
@@ -208,14 +210,12 @@ In practice this would typically be done using an alternative *decorator* syntax

Let's see how this is done.

### Decorator Notation

To target a function for JIT compilation we can put `@jit` before the function definition.
To target a function for JIT compilation we can put `@njit` before the function definition.

Here's what this looks like for `qm`

```{code-cell} python3
@jit
```{code-cell} ipython3
@njit
def qm(x0, n):
    x = np.empty(n+1)
    x[0] = x0
@@ -224,15 +224,21 @@ def qm(x0, n):
    return x
```

This is equivalent to `qm = jit(qm)`.
This is equivalent to `qm = njit(qm)`.

The following now uses the jitted version:

```{code-cell} python3
qm(0.1, 10)
```{code-cell} ipython3
%%time

qm(0.1, 100_000)
```

### Type Inference and "nopython" Mode
Numba provides several decorator arguments for accelerating computation and caching compiled functions, described [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).

In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.

## Type Inference

Clearly type inference is a key part of JIT compilation.

@@ -246,29 +252,87 @@ This allows it to generate native machine code, without having to call the Pytho

In such a setting, Numba will be on par with machine code from low-level languages.

When Numba cannot infer all type information, some Python objects are given generic object status and execution falls back to the Python runtime.
When Numba cannot infer all type information, it will raise an error.

When this happens, Numba provides only minor speed gains or none at all.
For example, in the case below, Numba is unable to determine the type of the function `mean` when compiling the function `bootstrap`:

We generally prefer to force an error when this occurs, so we know effective
compilation is failing.
```{code-cell} ipython3
@njit
def bootstrap(data, statistics, n):
    bootstrap_stat = np.empty(n)
    n_obs = len(data)    # sample size of each resample
    for i in range(n):   # n is the number of resamples
        resample = np.random.choice(data, size=n_obs, replace=True)
        bootstrap_stat[i] = statistics(resample)
    return bootstrap_stat

def mean(data):
    return np.mean(data)

data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
n_resamples = 10

print('Type of function:', type(mean))

# Error: `mean` is a plain Python function, so Numba cannot infer its type
try:
    bootstrap(data, mean, n_resamples)
except Exception as e:
    print(e)
```

But Numba recognizes JIT-compiled functions:

```{code-cell} ipython3
@njit
def qm(x0, n):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = 4 * x[t] * (1 - x[t])
    return x
def mean(data):
    return np.mean(data)

print('Type of function:', type(mean))

%time bootstrap(data, mean, n_resamples)
```

We can check the signature of the JIT-compiled function

```{code-cell} ipython3
bootstrap.signatures
```

It shows that the function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.

Now let's see what happens when we change the inputs.

Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.

```{code-cell} ipython3
data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
%time bootstrap(data, mean, 100)
bootstrap.signatures
```

As expected, the second run is much faster.

Let's change the data again, this time using an integer array:

```{code-cell} ipython3
data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
%time bootstrap(data, mean, 100)
bootstrap.signatures
```

Note that a second signature, with an `int64` array as its first argument, has been added to the signatures of `bootstrap`.

The run is slower, as if we were calling the function for the first time.

This suggests that Numba recompiles the function when the argument types change.

Overall, type inference is what allows Numba to achieve its performance, but it also limits what Numba supports, as the `bootstrap` example above shows.

In fact, this limitation means that Numba does not support everything in Python and its scientific libraries.

You can refer to the list of supported Python and NumPy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).

## Compiling Classes

As mentioned above, at present Numba can only compile a subset of Python.
@@ -285,7 +349,7 @@ created in {doc}`this lecture <python_oop>`.

To compile this class we use the `@jitclass` decorator:

```{code-cell} python3
```{code-cell} ipython3
from numba import float64
from numba.experimental import jitclass
```
@@ -298,7 +362,7 @@ We are importing it here because Numba needs a bit of extra help with types when

Here's our code:

```{code-cell} python3
```{code-cell} ipython3
solow_data = [
('n', float64),
('s', float64),
@@ -361,7 +425,7 @@ After that, targeting the class for JIT compilation only requires adding

When we call the methods in the class, the methods are compiled just like functions.

```{code-cell} python3
```{code-cell} ipython3
s1 = Solow()
s2 = Solow(k=8.0)

@@ -444,25 +508,25 @@ For larger ones, or for routines using external libraries, it can easily fail.

Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code.

This will give you much better performance than blanketing your Python programs with `@jit` statements.
This will give you much better performance than blanketing your Python programs with `@njit` statements.

### A Gotcha: Global Variables

Here's another thing to be careful about when using Numba.

Consider the following example

```{code-cell} python3
```{code-cell} ipython3
a = 1

@jit
@njit
def add_a(x):
    return a + x

print(add_a(10))
```

```{code-cell} python3
```{code-cell} ipython3
a = 2

print(add_a(10))
@@ -492,7 +556,7 @@ Compare speed with and without Numba when the sample size is large.

Here is one solution:

```{code-cell} python3
```{code-cell} ipython3
from random import uniform

@njit
@@ -581,13 +645,13 @@ We let
- 0 represent "low"
- 1 represent "high"

```{code-cell} python3
```{code-cell} ipython3
p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
```

Here's a pure Python version of the function

```{code-cell} python3
```{code-cell} ipython3
def compute_series(n):
    x = np.empty(n, dtype=np.int_)
    x[0] = 1 # Start in state 1
@@ -604,7 +668,7 @@ def compute_series(n):
Let's run this code and check that the fraction of time spent in the low
state is about 0.666

```{code-cell} python3
```{code-cell} ipython3
n = 1_000_000
x = compute_series(n)
print(np.mean(x == 0)) # Fraction of time x is in state 0
@@ -614,30 +678,28 @@ This is (approximately) the right output.

Now let's time it:

```{code-cell} python3
```{code-cell} ipython3
qe.tic()
compute_series(n)
qe.toc()
```

Next let's implement a Numba version, which is easy

```{code-cell} python3
from numba import jit

compute_series_numba = jit(compute_series)
```{code-cell} ipython3
compute_series_numba = njit(compute_series)
```

Let's check we still get the right numbers

```{code-cell} python3
```{code-cell} ipython3
x = compute_series_numba(n)
print(np.mean(x == 0))
```

Let's see the time

```{code-cell} python3
```{code-cell} ipython3
qe.tic()
compute_series_numba(n)
qe.toc()