From 5b3941eed7039179646e1fdd4659efacfbbca379 Mon Sep 17 00:00:00 2001 From: shlff Date: Mon, 16 Oct 2023 12:29:28 +1100 Subject: [PATCH] Update ifp_egm.md --- lectures/ifp_egm.md | 70 ++++++++++++++++++--------------------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/lectures/ifp_egm.md b/lectures/ifp_egm.md index cfe9c559..6d5f2937 100644 --- a/lectures/ifp_egm.md +++ b/lectures/ifp_egm.md @@ -4,14 +4,13 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.5 + jupytext_version: 1.15.2 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- - # Endogenous Grid Method @@ -43,6 +42,7 @@ import numpy as np import jax import jax.numpy as jnp +from collections import namedtuple from interpolation import interp from numba import njit, float64 from numba.experimental import jitclass @@ -60,7 +60,6 @@ We use 64 bit floating point numbers for extra precision. jax.config.update("jax_enable_x64", True) ``` - ## Setup We consider a household that chooses a state-contingent consumption plan $\{c_t\}_{t \geq 0}$ to maximize @@ -102,32 +101,35 @@ The following function stores default parameter values for the income fluctuation problem and creates suitable arrays. ```{code-cell} ipython3 -def ifp(R=1.01, # gross interest rate - β=0.99, # discount factor - γ=1.5, # CRRA preference parameter - s_max=16, # savings grid max - s_size=200, # savings grid size - ρ=0.99, # income persistence - ν=0.02, # income volatility - y_size=25): # income grid size - +Household = namedtuple('Household', ('β', 'R', 'γ', 's_size', 'y_size', \ + 's_grid', 'y_grid', 'P')) + +def create_household(R=1.01, # gross interest rate + β=0.99, # discount factor + γ=1.5, # CRRA preference parameter + s_max=16, # savings grid max + s_size=200, # savings grid size + ρ=0.99, # income persistence + ν=0.02, # income volatility + y_size=25): # income grid size + # Create income Markov chain mc = qe.tauchen(y_size, ρ, ν) y_grid, P = jnp.exp(mc.state_values), mc.P + # Shift to JAX arrays P, y_grid = jax.device_put((P, y_grid)) s_grid = jnp.linspace(0, s_max, s_size) - sizes = s_size, y_size s_grid, y_grid, P = jax.device_put((s_grid, y_grid, P)) # require R β < 1 for convergence assert R * β < 1, "Stability condition violated." - - return (β, R, γ), sizes, (s_grid, y_grid, P) + + return Household(β=β, R=R, γ=γ, s_size=s_size, y_size=y_size, \ + s_grid=s_grid, y_grid=y_grid, P=P) ``` - ## Solution method Let $S = \mathbb R_+ \times \mathsf Y$ be the set of possible values for the @@ -363,7 +365,6 @@ def K_egm(a_in, σ_in, constants, sizes, arrays): return a_out, σ_out ``` - Then we use `jax.jit` to compile $K$. We use `static_argnums` to allow a recompile whenever `sizes` changes, since the compiler likes to specialize on shapes. @@ -372,7 +373,6 @@ We use `static_argnums` to allow a recompile whenever `sizes` changes, since the K_egm_jax = jax.jit(K_egm, static_argnums=(3,)) ``` - Next we define a successive approximator that repeatedly applies $K$. ```{code-cell} ipython3 @@ -383,11 +383,11 @@ def successive_approx_jax(model, print_skip=25): # Unpack - constants, sizes, arrays = model + β, R, γ, s_size, y_size, s_grid, y_grid, P = model - β, R, γ = constants - s_size, y_size = sizes - s_grid, y_grid, P = arrays + constants = β, R, γ + sizes = s_size, y_size + arrays = s_grid, y_grid, P # Initial condition is to consume all in every state σ_init = jnp.repeat(s_grid, y_size) @@ -414,7 +414,6 @@ def successive_approx_jax(model, return a_new, σ_new ``` - ### Numba version Below we provide a second set of code, which solves the same model with Numba. @@ -436,10 +435,9 @@ ifp_data = [ ] # Use the JAX IFP data as our defaults for the Numba version -model = ifp() -constants, sizes, arrays = model -β, R, γ = constants -s_size, y_size = sizes +model = create_household() +β, R, γ, s_size, y_size, s_grid, y_grid, P = model +arrays = s_grid, y_grid, P s_grid, y_grid, P = (np.array(a) for a in arrays) @jitclass(ifp_data) @@ -539,7 +537,6 @@ def successive_approx_numba(model, # Class with model information return a_new, σ_new ``` - ## Solutions Here we solve the IFP with JAX and Numba. @@ -549,22 +546,22 @@ We will compare both the outputs and the execution time. ### Outputs ```{code-cell} ipython3 -ifp_jax = ifp() +ifp_jax = create_household() ``` ```{code-cell} ipython3 ifp_numba = IFP() ``` - Here's a first run of the JAX code. ```{code-cell} ipython3 +qe.tic() a_star_egm_jax, σ_star_egm_jax = successive_approx_jax(ifp_jax, print_skip=100) +qe.toc() ``` - Next let's solve the same IFP with Numba. ```{code-cell} ipython3 @@ -574,7 +571,6 @@ a_star_egm_nb, σ_star_egm_nb = successive_approx_numba(ifp_numba, qe.toc() ``` - Now let's check the outputs in a plot to make sure they are the same. ```{code-cell} ipython3 @@ -595,7 +591,6 @@ plt.legend() plt.show() ``` - ### Timing Now let's compare execution time of the two methods @@ -618,16 +613,7 @@ numba_time = qe.toc() jax_time / numba_time ``` - The JAX code is significantly faster, as expected. This difference will increase when more features (and state variables) are added to the model. - -```{code-cell} ipython3 - -``` - -```{code-cell} ipython3 - -```