From 91efc046c7583108c17c928c143609f03cd93c95 Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Fri, 15 Mar 2024 07:13:53 +1100
Subject: [PATCH] misc (#153)

---
 lectures/opt_savings_1.md | 58 +++++++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/lectures/opt_savings_1.md b/lectures/opt_savings_1.md
index d0f5bad9..90ae2303 100644
--- a/lectures/opt_savings_1.md
+++ b/lectures/opt_savings_1.md
@@ -56,34 +56,52 @@ $$
     W_{t+1} + C_t \leq R W_t + Y_t
 $$
 
-We assume that labor income $(Y_t)$ is a discretized AR(1) process.
+where
+
+* $C_t$ is consumption and $C_t \geq 0$,
+* $W_t$ is wealth and $W_t \geq 0$,
+* $R > 0$ is a gross rate of return, and
+* $(Y_t)$ is labor income.
+
+We assume below that labor income is a discretized AR(1) process.
 
-The right-hand side of the Bellman equation is
+The Bellman equation is
 
 $$
-B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
+    v(w, y) = \max_{0 \leq w' \leq Rw + y}
+    \left\{
+        u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y')
+    \right\}
 $$
 
 where
 
 $$
-u(c) = \frac{c^{1-\gamma}}{1-\gamma}
+    u(c) = \frac{c^{1-\gamma}}{1-\gamma}
 $$
 
-## Starting with NumPy
+In the code we use the function
+
+$$
+    B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
+$$
+
+to encapsulate the right-hand side of the Bellman equation.
+
 
-Let's start with a standard NumPy version, running on the CPU.
 
-This is a traditional approach using relatively old technologies.
+## Starting with NumPy
+
+Let's start with a standard NumPy version running on the CPU.
 
-Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
+Starting with this traditional approach will allow us to record the speed gain
+associated with switching to JAX.
 
 (NumPy operations are similar to MATLAB operations, so this also serves as a rough
 comparison with MATLAB.)
 
-
 ### Functions and operators
 
 The following function contains default parameters and returns tuples that
@@ -218,6 +236,8 @@ ax.legend()
 plt.show()
 ```
 
+
+
 ## Switching to JAX
 
 To switch over to JAX, we change `np` to `jnp` throughout and add some
@@ -284,7 +304,6 @@ def B(v, constants, sizes, arrays):
 
     return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
 
-B = jax.jit(B, static_argnums=(2,))
 ```
 
 Some readers might be concerned that we are creating high dimensional arrays,
@@ -295,6 +314,12 @@ Could they be avoided by more careful vectorization?
 
 In fact this is not necessary: this function will be JIT-compiled by JAX, and the
 JIT compiler will optimize compiled code to minimize memory use.
 
+```{code-cell} ipython3
+B = jax.jit(B, static_argnums=(2,))
+```
+
+In the call above, we indicate to the compiler that `sizes` is static, so the
+compiler can treat array sizes as fixed when optimizing (see the toy illustration of `static_argnums` further below).
 
 The Bellman operator $T$ can be implemented by
@@ -505,14 +530,19 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
 print(jnp.allclose(σ_star_vmap, σ_star_jax))
 ```
 
-Here's how long the `vmap` code takes relative to the first JAX implementation
-(which used direct vectorization).
+Here's the speed gain associated with switching from the NumPy version to JAX with `vmap`:
+
+```{code-cell} ipython3
+print(f"Relative speed = {numpy_elapsed / jax_vmap_elapsed}")
+```
+
+And here's the comparison with the first JAX implementation (which used direct vectorization).
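+
+Before running that comparison, here is the toy illustration of `static_argnums`
+promised above.  It is a minimal sketch using a hypothetical helper function
+`make_vec`, not code from the model itself.
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+
+def make_vec(scale, n):
+    # Toy example (not part of the lecture's model).
+    # The output shape depends on `n`, so `n` must be known at compile time.
+    return scale * jnp.ones(n)
+
+# Marking argument 1 as static tells the compiler to treat `n` as a constant
+# and to recompile if a different value of `n` is passed.
+make_vec = jax.jit(make_vec, static_argnums=(1,))
+
+print(make_vec(2.0, 5))
+```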
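+
+Similarly, here is a minimal sketch, on a hypothetical toy function `f` rather
+than the lecture's `B`, of the difference between direct vectorization
+(broadcasting over reshaped arrays) and `jax.vmap` (mapping a scalar function
+over array axes):
+
+```{code-cell} ipython3
+def f(x, y):
+    return jnp.cos(x) * jnp.exp(-y)
+
+x, y = jnp.linspace(0, 1, 3), jnp.linspace(0, 1, 4)
+
+# Direct vectorization: reshape so that broadcasting builds the full grid.
+direct = f(x.reshape(-1, 1), y.reshape(1, -1))
+
+# vmap: write f for scalars and let JAX map it over both axes.
+mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)
+
+print(jnp.allclose(direct, mapped))
+```
+
+With that distinction in mind, here is the relative timing of the two JAX versions: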
 ```{code-cell} ipython3
-print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
+print(f"Relative speed = {jax_elapsed / jax_vmap_elapsed}")
 ```
 
-The execution times are relatively similar.
+The execution times for the two JAX versions are relatively similar.
 
 However, as emphasized above, having a second method up our sleeves (i.e., the
 `vmap` approach) will be helpful when confronting dynamic programs with more