misc (#153)
jstac authored Mar 14, 2024
1 parent 0e3aa57 commit 91efc04
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions lectures/opt_savings_1.md
@@ -56,34 +56,52 @@ $$
W_{t+1} + C_t \leq R W_t + Y_t
$$

where

* $C_t$ is consumption and $C_t \geq 0$,
* $W_t$ is wealth and $W_t \geq 0$,
* $R > 0$ is a gross rate of return, and
* $(Y_t)$ is labor income.

We assume below that labor income is a discretized AR(1) process.
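
(In case the idea is unfamiliar, here is a rough sketch of one standard way to discretize an AR(1) process, Tauchen's method. The function name, parameter values, and the final exponentiation are illustrative assumptions and not necessarily what this lecture uses.)

```{code-cell} ipython3
import numpy as np
from scipy.stats import norm

# Rough sketch of Tauchen's method for y_{t+1} = ρ y_t + σ ε_{t+1}.
# Names and default values here are purely illustrative.
def discretize_ar1(ρ=0.9, σ=0.1, n=9, m=3):
    σ_y = σ / np.sqrt(1 - ρ**2)                # stationary std deviation
    grid = np.linspace(-m * σ_y, m * σ_y, n)   # evenly spaced grid
    step = grid[1] - grid[0]
    Q = np.empty((n, n))
    for i, y in enumerate(grid):
        z = (grid - ρ * y) / σ
        Q[i, 0]  = norm.cdf(z[0] + step / (2 * σ))
        Q[i, -1] = 1 - norm.cdf(z[-1] - step / (2 * σ))
        for j in range(1, n - 1):
            Q[i, j] = norm.cdf(z[j] + step / (2 * σ)) - norm.cdf(z[j] - step / (2 * σ))
    return np.exp(grid), Q    # exponentiate so that income is positive
```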

The Bellman equation is

$$
v(w, y) = \max_{0 \leq w' \leq Rw + y}
\left\{
u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y')
\right\}
$$

where

$$
u(c) = \frac{c^{1-\gamma}}{1-\gamma}
$$

In the code we use the function

$$
B((w, y), w', v) = u(Rw + y - w') + β \sum_{y'} v(w', y') Q(y, y').
$$

to encapsulate the right-hand side of the Bellman equation.
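
For concreteness, here is a rough scalar sketch of such a function. The grid and parameter names below are hypothetical, and the lecture's own implementation works with entire arrays rather than one state at a time.

```{code-cell} ipython3
import numpy as np

# Illustrative sketch only; names and default values are assumptions.
def B_scalar(i, j, ip, v, w_grid, y_grid, Q, R=1.01, β=0.96, γ=2.0):
    """
    Evaluate B((w, y), w', v) at wealth index i, income index j and
    savings-choice index ip, where v has shape (len(w_grid), len(y_grid)).
    """
    w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
    c = R * w + y - wp                     # implied consumption
    if c <= 0:
        return -np.inf                     # infeasible choice
    EV = np.sum(v[ip, :] * Q[j, :])        # E[v(w', y') | y]
    return c**(1 - γ) / (1 - γ) + β * EV
```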


## Starting with NumPy

Let's start with a standard NumPy version running on the CPU.

Starting with this traditional approach will allow us to record the speed gain
associated with switching to JAX.

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)
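
To preview the "direct vectorization" strategy referred to later, here is a rough sketch of how all the right-hand-side values can be computed at once with NumPy broadcasting. The array names and reshapes are illustrative assumptions, not the lecture's exact code.

```{code-cell} ipython3
import numpy as np

# Illustrative sketch of direct vectorization; names are assumptions.
def B_numpy(v, w_grid, y_grid, Q, R=1.01, β=0.96, γ=2.0):
    """
    Build the array B[i, j, ip] of right-hand-side values for every
    (current wealth, current income, savings choice) triple at once.
    """
    w  = w_grid.reshape(-1, 1, 1)    # current wealth w, axis 0
    y  = y_grid.reshape(1, -1, 1)    # current income y, axis 1
    wp = w_grid.reshape(1, 1, -1)    # next-period wealth w', axis 2
    c = R * w + y - wp               # consumption, shape (n_w, n_y, n_w)
    EV = v @ Q.T                     # EV[ip, j] = Σ_{y'} v[ip, y'] Q[j, y']
    EV = EV.T[np.newaxis, :, :]      # align with the (i, j, ip) axes
    return np.where(c > 0, c**(1 - γ) / (1 - γ) + β * EV, -np.inf)
```

Maximizing this array over its last axis then updates the value function on the whole grid in one step.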




### Functions and operators

The following function contains default parameters and returns tuples that
@@ -218,6 +236,8 @@ ax.legend()
plt.show()
```



## Switching to JAX

To switch over to JAX, we change `np` to `jnp` throughout and add some
Expand Down Expand Up @@ -284,7 +304,6 @@ def B(v, constants, sizes, arrays):
return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
```

Some readers might be concerned that we are creating high dimensional arrays,
@@ -295,6 +314,12 @@ Could they be avoided by more careful vectorization?
In fact this is not necessary: this function will be JIT-compiled by JAX, and
the JIT compiler will optimize compiled code to minimize memory use.

```{code-cell} ipython3
B = jax.jit(B, static_argnums=(2,))
```

In the call above, we indicate to the compiler that `sizes` is static, so the
compiler can parallelize optimally while taking array sizes as fixed.
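
As a toy illustration of why static arguments matter (this example is separate from the lecture's model code): a jitted function whose output shape depends on an argument needs that argument to be treated as static.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

def make_and_add(x, n):
    # n determines an array shape, so it must be known at compile time
    return x + jnp.zeros(n)

make_and_add = jax.jit(make_and_add, static_argnums=(1,))
make_and_add(1.0, 5)   # compiles for n=5; calling with a new n triggers recompilation
```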

The Bellman operator $T$ can be implemented by
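maximizing `B` over the savings choice $w'$. Here is a minimal sketch, under
the assumption that this choice sits on the last axis of the array returned
by `B`:

```{code-cell} ipython3
# Sketch of the Bellman operator, assuming the last axis of the array
# returned by B ranges over the savings choice w'.
def T(v, constants, sizes, arrays):
    return jnp.max(B(v, constants, sizes, arrays), axis=-1)

T = jax.jit(T, static_argnums=(2,))
```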

@@ -505,14 +530,19 @@ print(jnp.allclose(v_star_vmap, v_star_jax))
print(jnp.allclose(σ_star_vmap, σ_star_jax))
```

Here's the speed gain associated with switching from the NumPy version to JAX with `vmap`:

```{code-cell} ipython3
print(f"Relative speed = {numpy_elapsed / jax_vmap_elapsed}")
```

And here's the comparison with the first JAX implementation (which used direct vectorization).

```{code-cell} ipython3
print(f"Relative speed = {jax_vmap_elapsed / jax_elapsed}")
print(f"Relative speed = {jax_elapsed / jax_vmap_elapsed}")
```

The execution times for the two JAX versions are relatively similar.
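
(In case `vmap` itself is unfamiliar, here is a toy illustration of what it does, separate from the savings model above.)

```{code-cell} ipython3
# Toy illustration of jax.vmap: apply a scalar function across an array
# without writing an explicit loop.
def square(x):
    return x ** 2

batched_square = jax.vmap(square)
batched_square(jnp.arange(5.0))
```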

However, as emphasized above, having a second method up our sleeves (i.e., the
`vmap` approach) will be helpful when confronting dynamic programs with more
