Skip to content

Commit

Permalink
Small patches to optimal savings (#152)
Browse files Browse the repository at this point in the history
* misc

* misc
  • Loading branch information
jstac authored Mar 14, 2024
1 parent 216ed17 commit 0e3aa57
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
31 changes: 23 additions & 8 deletions lectures/opt_savings_1.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU.

This is a traditional approach using relatively old technologies.

One reason we start with NumPy is that switching from NumPy to JAX will be
relatively trivial.

The other reason is that we want to know the speed gain associated with
switching to JAX.
Starting with NumPy will allow us to record the speed gain associated with switching to JAX.

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)




### Functions and operators

The following function contains default parameters and returns tuples that
Expand All @@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
w_grid = np.linspace(w_min, w_max, w_size)
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
y_grid, Q = np.exp(mc.state_values), mc.P
w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
sizes = w_size, y_size
return (β, R, γ), sizes, (w_grid, y_grid, Q)
```
Expand Down Expand Up @@ -397,9 +395,19 @@ The relative speed gain is
print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
```


This is an impressive speed up and in fact we can do better still by switching
to alternative algorithms that are better suited to parallelization.

These algorithms are discussed in a {doc}`separate lecture <opt_savings_2>`.


## Switching to vmap

For this simple optimal savings problem direct vectorization is relatively easy.
Before we discuss alternative algorithms, let's take another look at value
function iteration.

For this simple optimal savings problem, direct vectorization is relatively easy.

In particular, it's straightforward to express the right hand side of the
Bellman equation as an array that stores evaluations of the function at every
Expand Down Expand Up @@ -497,8 +505,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
print(jnp.allclose(σ_star_vmap, σ_star_jax))
```

The relative speed is
Here's how long the `vmap` code takes relative to the first JAX implementation
(which used direct vectorization).

```{code-cell} ipython3
print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
```

The execution times are relatively similar.

However, as emphasized above, having a second method up our sleeves (i.e, the
`vmap` approach) will be helpful when confronting dynamic programs with more
sophisticated Bellman equations.
4 changes: 1 addition & 3 deletions lectures/opt_savings_2.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
Here's the right hand side of the Bellman equation:

```{code-cell} ipython3
:tags: [hide-input]
def B(v, constants, sizes, arrays):
"""
A vectorized version of the right-hand side of the Bellman equation
Expand Down Expand Up @@ -427,4 +425,4 @@ ax.legend(frameon=False)
ax.set_xlabel("$m$")
ax.set_ylabel("time")
plt.show()
```
```

0 comments on commit 0e3aa57

Please sign in to comment.