diff --git a/lectures/opt_savings_1.md b/lectures/opt_savings_1.md
index 960acdd4..d0f5bad9 100644
--- a/lectures/opt_savings_1.md
+++ b/lectures/opt_savings_1.md
@@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU.
 
 This is a traditional approach using relatively old technologies.
 
-One reason we start with NumPy is that switching from NumPy to JAX will be
-relatively trivial.
-
-The other reason is that we want to know the speed gain associated with
-switching to JAX.
+Starting with NumPy will allow us to record the speed gain associated with switching to JAX.
 
 (NumPy operations are similar to MATLAB operations, so this also serves as a
 rough comparison with MATLAB.)
 
+
+
+
 ### Functions and operators
 
 The following function contains default parameters and returns tuples that
@@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
     w_grid = np.linspace(w_min, w_max, w_size)
     mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
     y_grid, Q = np.exp(mc.state_values), mc.P
-    w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
     sizes = w_size, y_size
     return (β, R, γ), sizes, (w_grid, y_grid, Q)
 ```
@@ -397,9 +395,19 @@ The relative speed gain is
 print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
 ```
 
+
+This is an impressive speedup and, in fact, we can do better still by switching
+to alternative algorithms that are better suited to parallelization.
+
+These algorithms are discussed in a {doc}`separate lecture `.
+
+
 ## Switching to vmap
 
-For this simple optimal savings problem direct vectorization is relatively easy.
+Before we discuss alternative algorithms, let's take another look at value
+function iteration.
+
+For this simple optimal savings problem, direct vectorization is relatively easy.
 In particular, it's straightforward to express the right hand side of the
 Bellman equation as an array that stores evaluations of the function at every
@@ -497,8 +505,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
 print(jnp.allclose(σ_star_vmap, σ_star_jax))
 ```
 
-The relative speed is
+Here's how long the `vmap` code takes relative to the first JAX implementation
+(which used direct vectorization).
 
 ```{code-cell} ipython3
 print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
 ```
+
+The execution times are relatively similar.
+
+However, as emphasized above, having a second method up our sleeves (i.e., the
+`vmap` approach) will be helpful when confronting dynamic programs with more
+sophisticated Bellman equations.
diff --git a/lectures/opt_savings_2.md b/lectures/opt_savings_2.md
index 6e3dfaff..a05fbe10 100644
--- a/lectures/opt_savings_2.md
+++ b/lectures/opt_savings_2.md
@@ -110,8 +110,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
 Here's the right hand side of the Bellman equation:
 
 ```{code-cell} ipython3
-:tags: [hide-input]
-
 def B(v, constants, sizes, arrays):
     """
     A vectorized version of the right-hand side of the Bellman equation
@@ -427,4 +425,4 @@ ax.legend(frameon=False)
 ax.set_xlabel("$m$")
 ax.set_ylabel("time")
 plt.show()
-```
\ No newline at end of file
+```
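
A note on the `vmap` section this diff introduces: the two vectorization styles it compares (direct broadcasting over reshaped arrays versus `jax.vmap` applied to a scalar function) can be illustrated with a toy example. This is only a sketch with a made-up payoff function and grids, not the lecture's actual `B` operator:

```python
# Sketch contrasting direct vectorization with jax.vmap.
# f is a toy stand-in for the Bellman right-hand side at one (w, y, w') triple.
import jax
import jax.numpy as jnp

def f(w, y, wp):
    # log-utility of consumption w + y - w', with -inf for infeasible choices
    return jnp.where(w + y - wp > 0, jnp.log(w + y - wp), -jnp.inf)

w_grid = jnp.linspace(0.1, 5.0, 4)
y_grid = jnp.linspace(0.5, 1.5, 3)

# Style 1: direct vectorization -- reshape so broadcasting builds the
# full (w, y, w') array in one call.
B_direct = f(w_grid.reshape(4, 1, 1),
             y_grid.reshape(1, 3, 1),
             w_grid.reshape(1, 1, 4))

# Style 2: vmap the scalar function over one argument at a time,
# from the innermost axis (w') out to the outermost (w).
f_wp = jax.vmap(f,    in_axes=(None, None, 0))   # vary w'
f_y  = jax.vmap(f_wp, in_axes=(None, 0, None))   # vary y
f_w  = jax.vmap(f_y,  in_axes=(0, None, None))   # vary w
B_vmap = f_w(w_grid, y_grid, w_grid)

print(B_direct.shape, B_vmap.shape)     # (4, 3, 4) (4, 3, 4)
print(jnp.allclose(B_direct, B_vmap))   # True
```

The payoff of the `vmap` style, as the added prose notes, is that the scalar function stays readable while the nesting order of the `vmap` calls controls the axis layout of the output.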