Small patches to optimal savings #152

Merged: 2 commits, Mar 14, 2024
31 changes: 23 additions & 8 deletions lectures/opt_savings_1.md
@@ -76,15 +76,14 @@ Let's start with a standard NumPy version, running on the CPU.

This is a traditional approach using relatively old technologies.

-One reason we start with NumPy is that switching from NumPy to JAX will be
-relatively trivial.
-
-The other reason is that we want to know the speed gain associated with
-switching to JAX.
+Starting with NumPy will allow us to record the speed gain associated with switching to JAX.

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)




### Functions and operators

The following function contains default parameters and returns tuples that
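The hunk above mentions recording the speed gain from switching to JAX. As a minimal sketch (the solver names and `model` object here are placeholders, not the lecture's actual routines), the elapsed times compared further down might be collected along these lines; the `block_until_ready()` call matters because JAX dispatches computations asynchronously:

```python
from time import time

# Hypothetical solvers standing in for the lecture's NumPy and JAX VFI routines
start = time()
v_star_numpy, σ_star_numpy = solve_model_numpy(model)
numpy_elapsed = time() - start

start = time()
v_star_jax, σ_star_jax = solve_model_jax(model)
v_star_jax.block_until_ready()   # wait for the asynchronous JAX computation to finish
jax_elapsed = time() - start

print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
```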
@@ -106,7 +105,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
    w_grid = np.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = np.exp(mc.state_values), mc.P
-   w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```
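Pieced together from the hunk above, the post-patch function plausibly reads as follows. The default values for every parameter except `R` are hidden by the diff fold, so the ones shown here are assumptions for illustration only:

```python
import numpy as np
import quantecon as qe

def create_consumption_model(R=1.01,       # Gross interest rate
                             β=0.98,       # Discount factor (assumed default)
                             γ=2.0,        # CRRA utility parameter (assumed default)
                             w_min=0.01,   # Min wealth (assumed default)
                             w_max=5.0,    # Max wealth (assumed default)
                             w_size=150,   # Wealth grid size (assumed default)
                             ρ=0.9,        # Income persistence (assumed default)
                             ν=0.1,        # Income volatility (assumed default)
                             y_size=100):  # Income grid size (assumed default)
    """Build grids and the income transition matrix for the savings problem."""
    w_grid = np.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = np.exp(mc.state_values), mc.P
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```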
@@ -397,9 +395,19 @@ The relative speed gain is
print(f"Relative speed gain = {numpy_elapsed / jax_elapsed}")
```


+This is an impressive speedup, and in fact we can do better still by switching
+to alternative algorithms that are better suited to parallelization.
+
+These algorithms are discussed in a {doc}`separate lecture <opt_savings_2>`.


## Switching to vmap

-For this simple optimal savings problem direct vectorization is relatively easy.
+Before we discuss alternative algorithms, let's take another look at value
+function iteration.
+
+For this simple optimal savings problem, direct vectorization is relatively easy.

In particular, it's straightforward to express the right hand side of the
Bellman equation as an array that stores evaluations of the function at every
@@ -497,8 +505,15 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
print(jnp.allclose(σ_star_vmap, σ_star_jax))
```

-The relative speed is
+Here's how long the `vmap` code takes relative to the first JAX implementation
+(which used direct vectorization).

```{code-cell} ipython3
print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
```

+The execution times are relatively similar.

However, as emphasized above, having a second method up our sleeves (i.e., the
`vmap` approach) will be helpful when confronting dynamic programs with more
sophisticated Bellman equations.
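To make the `vmap` discussion above concrete before moving on to the second file: here is a sketch, not the lecture's verbatim code, of how the right-hand side of the Bellman equation can be written for a single state-choice triple and then vectorized over all indices with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def B_scalar(v, constants, arrays, i, j, ip):
    """RHS of the Bellman equation at one (w[i], y[j], w[ip]) triple."""
    β, R, γ = constants
    w_grid, y_grid, Q = arrays
    c = R * w_grid[i] + y_grid[j] - w_grid[ip]   # consumption implied by the choice
    EV = jnp.sum(v[ip, :] * Q[j, :])             # expected continuation value
    return jnp.where(c > 0, c**(1 - γ) / (1 - γ) + β * EV, -jnp.inf)

# Vectorize over ip, then j, then i, producing a (w_size, y_size, w_size) array
B_1 = jax.vmap(B_scalar, in_axes=(None, None, None, None, None, 0))
B_2 = jax.vmap(B_1, in_axes=(None, None, None, None, 0, None))
B_vmap = jax.vmap(B_2, in_axes=(None, None, None, 0, None, None))
```

Maximizing the output of `B_vmap` over its last axis then gives the Bellman update at every state, which is the step that value function iteration repeats.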
4 changes: 1 addition & 3 deletions lectures/opt_savings_2.md
@@ -110,8 +110,6 @@ def create_consumption_model(R=1.01, # Gross interest rate
Here's the right hand side of the Bellman equation:
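
(As a reference point, reconstructed from the model pieces visible in this diff rather than quoted from the lecture: the Bellman equation here is

$$
v(w, y) = \max_{0 \le w' \le Rw + y}
\left\{ \frac{(Rw + y - w')^{1-\gamma}}{1-\gamma}
+ \beta \sum_{y'} v(w', y') \, Q(y, y') \right\},
$$

and `B` evaluates the expression in braces at every grid triple $(w, y, w')$, before maximization over $w'$.)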

```{code-cell} ipython3
-:tags: [hide-input]
-
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
@@ -427,4 +425,4 @@
ax.set_xlabel("$m$")
ax.set_ylabel("time")
plt.show()
-```
+```
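The fold above hides the body of `B`. For completeness, here is a sketch of the broadcasting version its docstring describes, reconstructed under the same model conventions as before rather than copied from the lecture:

```python
import jax.numpy as jnp

def B(v, constants, sizes, arrays):
    """
    Sketch of a vectorized right-hand side of the Bellman equation
    (before maximization), returning a (w_size, y_size, w_size) array.
    """
    β, R, γ = constants
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays

    # Reshape so that broadcasting indexes (current w, current y, next w)
    w = jnp.reshape(w_grid, (w_size, 1, 1))
    y = jnp.reshape(y_grid, (1, y_size, 1))
    wp = jnp.reshape(w_grid, (1, 1, w_size))
    c = R * w + y - wp                             # consumption at each triple

    # Expected continuation value: EV[i, j, ip] = Σ_{jp} v[ip, jp] * Q[j, jp]
    v_next = jnp.reshape(v, (1, 1, w_size, y_size))
    Q_next = jnp.reshape(Q, (1, y_size, 1, y_size))
    EV = jnp.sum(v_next * Q_next, axis=-1)

    return jnp.where(c > 0, c**(1 - γ) / (1 - γ) + β * EV, -jnp.inf)
```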