
Update ifp_egm.md
shlff committed Oct 16, 2023
1 parent b248d59 commit 5b3941e
Showing 1 changed file with 28 additions and 42 deletions.
70 changes: 28 additions & 42 deletions lectures/ifp_egm.md
@@ -4,14 +4,13 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.5
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---


# Endogenous Grid Method


@@ -43,6 +42,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
from interpolation import interp
from numba import njit, float64
from numba.experimental import jitclass
@@ -60,7 +60,6 @@ We use 64 bit floating point numbers for extra precision.
jax.config.update("jax_enable_x64", True)
```


## Setup

We consider a household that chooses a state-contingent consumption plan $\{c_t\}_{t \geq 0}$ to maximize
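(the objective itself is in a part of the file not shown in this diff; for this lecture series it is presumably the standard CRRA criterion

$$
\mathbb E \, \sum_{t=0}^{\infty} \beta^t \frac{c_t^{1-\gamma}}{1-\gamma},
$$

subject to a budget constraint of the form $a_{t+1} = R(a_t - c_t) + Y_{t+1}$ with $0 \leq c_t \leq a_t$, using the parameters $\beta$, $R$ and $\gamma$ defined in the code below).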
@@ -102,32 +101,35 @@ The following function stores default parameter values for the income
fluctuation problem and creates suitable arrays.

```{code-cell} ipython3
def ifp(R=1.01, # gross interest rate
β=0.99, # discount factor
γ=1.5, # CRRA preference parameter
s_max=16, # savings grid max
s_size=200, # savings grid size
ρ=0.99, # income persistence
ν=0.02, # income volatility
y_size=25): # income grid size
Household = namedtuple('Household', ('β', 'R', 'γ', 's_size', 'y_size', \
's_grid', 'y_grid', 'P'))
def create_household(R=1.01, # gross interest rate
β=0.99, # discount factor
γ=1.5, # CRRA preference parameter
s_max=16, # savings grid max
s_size=200, # savings grid size
ρ=0.99, # income persistence
ν=0.02, # income volatility
y_size=25): # income grid size
# Create income Markov chain
mc = qe.tauchen(y_size, ρ, ν)
y_grid, P = jnp.exp(mc.state_values), mc.P
# Shift to JAX arrays
P, y_grid = jax.device_put((P, y_grid))
s_grid = jnp.linspace(0, s_max, s_size)
sizes = s_size, y_size
s_grid, y_grid, P = jax.device_put((s_grid, y_grid, P))
# require R β < 1 for convergence
assert R * β < 1, "Stability condition violated."
return (β, R, γ), sizes, (s_grid, y_grid, P)
return Household(β=β, R=R, γ=γ, s_size=s_size, y_size=y_size, \
s_grid=s_grid, y_grid=y_grid, P=P)
```
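As a quick sanity check, the constructor can be called and its fields inspected (a minimal sketch, assuming the namedtuple version of `create_household` shown above):

```python
# Build a household with default parameters and inspect the stored arrays.
household = create_household()
print(household.β, household.R, household.γ)
print(household.s_grid.shape, household.y_grid.shape, household.P.shape)
```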


## Solution method

Let $S = \mathbb R_+ \times \mathsf Y$ be the set of possible values for the
@@ -363,7 +365,6 @@ def K_egm(a_in, σ_in, constants, sizes, arrays):
return a_out, σ_out
```
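For reference, the endogenous grid update that $K$ implements can be summarized as follows (a hedged summary of the standard EGM step for this model; the derivation itself is in a part of the lecture not shown in this diff). Given a savings grid point $s_i$ and current income state $y_j$, the updated consumption and endogenous asset level are presumably

$$
c_{ij} = (u')^{-1}\left( \beta R \sum_{k} u'\big(\sigma(R s_i + y_k, \, y_k)\big) P(j, k) \right),
\qquad
a_{ij} = s_i + c_{ij},
$$

where $(u')^{-1}(x) = x^{-1/\gamma}$ under CRRA utility, modulo handling of the borrowing constraint.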
Then we use `jax.jit` to compile $K$.
We use `static_argnums` so that the function is recompiled whenever `sizes` changes, since the compiler specializes on array shapes.
@@ -372,7 +373,6 @@ We use `static_argnums` to allow a recompile whenever `sizes` changes, since the
K_egm_jax = jax.jit(K_egm, static_argnums=(3,))
```
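A minimal illustration of why `sizes` is passed as a static argument (a toy sketch, not part of the lecture): values marked static are baked into the compiled code, so shape-determining integers can be used where JAX requires concrete sizes, at the cost of a retrace when they change.

```python
# Toy example: the grid length must be a concrete (static) value.
import jax
import jax.numpy as jnp

def make_grid(scale, n):
    return scale * jnp.linspace(0.0, 1.0, n)

make_grid_jit = jax.jit(make_grid, static_argnums=(1,))

print(make_grid_jit(2.0, 4))   # compiles for n=4
print(make_grid_jit(2.0, 6))   # recompiles for n=6
```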
Next we define a successive approximator that repeatedly applies $K$.
```{code-cell} ipython3
@@ -383,11 +383,11 @@ def successive_approx_jax(model,
print_skip=25):
# Unpack
constants, sizes, arrays = model
β, R, γ, s_size, y_size, s_grid, y_grid, P = model
β, R, γ = constants
s_size, y_size = sizes
s_grid, y_grid, P = arrays
constants = β, R, γ
sizes = s_size, y_size
arrays = s_grid, y_grid, P
# Initial condition is to consume all in every state
σ_init = jnp.repeat(s_grid, y_size)
@@ -414,7 +414,6 @@ def successive_approx_jax(model,
return a_new, σ_new
```
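The structure of the loop is ordinary successive approximation. A stripped-down sketch of the same pattern (illustrative only; the names and tolerances here are assumptions, not the lecture's code):

```python
# Generic fixed-point iteration: apply T until successive iterates are close.
import jax.numpy as jnp

def fixed_point(T, x0, tol=1e-5, max_iter=10_000):
    x = x0
    for i in range(max_iter):
        x_new = T(x)
        if jnp.max(jnp.abs(x_new - x)) < tol:
            return x_new
        x = x_new
    raise ValueError("Iteration failed to converge.")

# Converges to the fixed point 2.0 of x ↦ 0.5 x + 1.
print(fixed_point(lambda x: 0.5 * x + 1.0, jnp.array(0.0)))
```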
### Numba version
Below we provide a second set of code, which solves the same model with Numba.
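The Numba version packages parameters and arrays into a `jitclass` built from a type specification. As a warm-up, here is a minimal self-contained sketch of that pattern (all names here are hypothetical, not the lecture's spec):

```python
# Toy jitclass: a spec lists each attribute with its Numba type.
import numpy as np
from numba import float64
from numba.experimental import jitclass

toy_spec = [
    ('β', float64),        # scalar attribute
    ('grid', float64[:]),  # 1D float array attribute
]

@jitclass(toy_spec)
class ToyModel:
    def __init__(self, β, n):
        self.β = β
        self.grid = np.linspace(0.0, 1.0, n)

    def discounted_sum(self):
        return self.β * self.grid.sum()

m = ToyModel(0.96, 5)
print(m.discounted_sum())
```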
@@ -436,10 +435,9 @@ ifp_data = [
]
# Use the JAX IFP data as our defaults for the Numba version
model = ifp()
constants, sizes, arrays = model
β, R, γ = constants
s_size, y_size = sizes
model = create_household()
β, R, γ, s_size, y_size, s_grid, y_grid, P = model
arrays = s_grid, y_grid, P
s_grid, y_grid, P = (np.array(a) for a in arrays)
@jitclass(ifp_data)
@@ -539,7 +537,6 @@ def successive_approx_numba(model, # Class with model information
return a_new, σ_new
```
## Solutions
Here we solve the IFP with JAX and Numba.
@@ -549,22 +546,22 @@ We will compare both the outputs and the execution time.
### Outputs
```{code-cell} ipython3
ifp_jax = ifp()
ifp_jax = create_household()
```
```{code-cell} ipython3
ifp_numba = IFP()
```
Here's a first run of the JAX code.
```{code-cell} ipython3
qe.tic()
a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(ifp_jax,
print_skip=100)
qe.toc()
```
Next let's solve the same IFP with Numba.
```{code-cell} ipython3
@@ -574,7 +571,6 @@ a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(ifp_numba,
qe.toc()
```
Now let's check the outputs in a plot to make sure they are the same.
```{code-cell} ipython3
@@ -595,7 +591,6 @@ plt.legend()
plt.show()
```
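A numerical check can complement the plot (a sketch; this assumes the two solvers return policies with matching shapes, which may require a reshape in practice):

```python
# Largest absolute gap between the JAX and Numba consumption policies.
gap = np.max(np.abs(np.array(σ_star_egm_jax) - np.array(σ_star_egm_nb)))
print(f"max policy difference: {gap:.2e}")
```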
### Timing
Now let's compare the execution time of the two methods.
@@ -618,16 +613,7 @@ numba_time = qe.toc()
jax_time / numba_time
```
The JAX code is significantly faster, as expected.
This difference will increase when more features (and state variables) are added
to the model.
```{code-cell} ipython3
```
```{code-cell} ipython3
```
