---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.1
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---
# Optimal Savings

```{include} _admonition/gpu.md
```

In addition to what’s in Anaconda, this lecture will need the following libraries:

```{code-cell} ipython3
:tags: [hide-output]
!pip install quantecon
```

We will use the following imports:

```{code-cell} ipython3
import quantecon as qe
import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
import matplotlib.pyplot as plt
import time
```

Let's check the GPU we are running on:

```{code-cell} ipython3
!nvidia-smi
```

We'll use 64-bit floats to gain extra precision.

```{code-cell} ipython3
jax.config.update("jax_enable_x64", True)
```

## Overview

We consider an optimal savings problem with CRRA utility and budget constraint

$$
    W_{t+1} + C_t \leq R W_t + Y_t
$$

We assume that labor income $(Y_t)$ is a discretized AR(1) process.

The right-hand side of the Bellman equation is

$$
    B((w, y), w', v) = u(Rw + y - w') + \beta \sum_{y'} v(w', y') Q(y, y')
$$

where

$$
    u(c) = \frac{c^{1-\gamma}}{1-\gamma}
$$

The following function sets default parameters and returns tuples that
contain the key computational components of the model.

```{code-cell} ipython3
def create_consumption_model(R=1.01,        # Gross interest rate
                             β=0.98,        # Discount factor
                             γ=2,           # CRRA parameter
                             w_min=0.01,    # Min wealth
                             w_max=5.0,     # Max wealth
                             w_size=150,    # Wealth grid size
                             ρ=0.9, ν=0.1, y_size=100):  # Income parameters
    """
    A function that takes in parameters and returns parameters and grids
    for the optimal savings problem.
    """
    w_grid = jnp.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = jnp.exp(mc.state_values), mc.P
    β, R, γ = jax.device_put([β, R, γ])
    w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```

The function above returns the array sizes separately because we pass them as
static arguments when JIT-compiling the functions below.
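
As a minimal illustration (not part of the model code), here is how a static
argument lets a jitted function use a plain Python value to fix an output
shape at trace time; the function `make_zeros` is hypothetical:

```{code-cell} ipython3
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def make_zeros(n):
    # `n` is static, so its concrete value is known when the function
    # is traced and can therefore determine the output shape.
    return jnp.zeros(n)

print(make_zeros(5).shape)
```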

## A solution using vectorization

We will start with a fully vectorized solution, in the sense that the
right-hand side of the Bellman equation is represented as a multi-dimensional
array with dimensions over all states and controls.
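
Before looking at the model code, here is a toy example (not from the lecture)
of the broadcasting pattern it relies on: reshaping 1D arrays so that an
elementwise operation produces an array indexed over every combination.

```{code-cell} ipython3
a = jnp.arange(3).reshape(3, 1, 1)   # varies over axis 0
b = jnp.arange(4).reshape(1, 4, 1)   # varies over axis 1
c = jnp.arange(5).reshape(1, 1, 5)   # varies over axis 2
print((a + b + c).shape)             # (3, 4, 5)
```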

(Later we will examine an alternative method that uses `vmap`.)

Here's the right-hand side of the Bellman equation as a vectorized expression:
```{code-cell} ipython3
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing

        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

    for all (w, y, w′).
    """
    # Unpack
    β, R, γ = constants
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays

    # Compute current consumption as array c[i, j, ip]
    w = jnp.reshape(w_grid, (w_size, 1, 1))    # w[i]   -> w[i, j, ip]
    y = jnp.reshape(y_grid, (1, y_size, 1))    # y[j]   -> y[i, j, ip]
    wp = jnp.reshape(w_grid, (1, 1, w_size))   # wp[ip] -> wp[i, j, ip]
    c = R * w + y - wp

    # Calculate continuation rewards at all combinations of (w, y, wp)
    v = jnp.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation, assigning -inf
    # to infeasible choices (those with c <= 0)
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)

B = jax.jit(B, static_argnums=(2,))
```

Readers familiar with MATLAB might be concerned that we are creating high
dimensional arrays, leading to inefficiency.

Could these arrays be avoided with more careful coding?

In fact this is not necessary: the function will be JIT-compiled by JAX, and
the JIT compiler will optimize execution to minimize memory use.

### The Bellman operator

The Bellman operator $T$ can be implemented by

```{code-cell} ipython3
def T(v, constants, sizes, arrays):
    "The Bellman operator."
    return jnp.max(B(v, constants, sizes, arrays), axis=2)

T = jax.jit(T, static_argnums=(2,))
```
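
As a quick illustrative check (not in the original), one application of $T$
maps a value array of shape `(w_size, y_size)` to an array of the same shape;
the throwaway instance below exists only for this check:

```{code-cell} ipython3
_constants, _sizes, _arrays = create_consumption_model()
v0 = jnp.zeros(_sizes)
print(T(v0, _constants, _sizes, _arrays).shape)   # (w_size, y_size)
```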

The next function computes a $v$-greedy policy given $v$ (i.e., the policy
that maximizes the right-hand side of the Bellman equation).

```{code-cell} ipython3
def get_greedy(v, constants, sizes, arrays):
    "Computes a v-greedy policy, returned as a set of indices."
    return jnp.argmax(B(v, constants, sizes, arrays), axis=2)

get_greedy = jax.jit(get_greedy, static_argnums=(2,))
```

### Successive approximation

Now we define a solver that implements VFI by successive approximation, using
the fixed-point iterator `successive_approx_jax` (see the sketch below).
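
The routine `successive_approx_jax` is not defined in this file, so it is
presumably supplied elsewhere in the lecture series.  For completeness, here
is a minimal sketch of a compatible iterator; the exact signature and stopping
rule are assumptions:

```{code-cell} ipython3
def successive_approx_jax(T,                   # operator (callable)
                          x_0,                 # initial condition
                          tolerance=1e-6,      # error tolerance
                          max_iter=10_000):    # max iteration bound
    "Iterate x_{k+1} = T(x_k) until the sup-norm change is below tolerance."
    x = x_0
    error = tolerance + 1
    k = 1
    while error > tolerance and k <= max_iter:
        x_new = T(x)
        error = float(abs(x_new - x).max())
        x = x_new
        k += 1
    return x
```

(Despite the name, this pure-Python version also works with NumPy arrays,
which matters for the comparison later in the lecture.)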

```{code-cell} ipython3
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    vz = jnp.zeros(sizes)
    _T = lambda v: T(v, constants, sizes, arrays)
    v_star = successive_approx_jax(_T, vz, tolerance=tol)
    return v_star, get_greedy(v_star, constants, sizes, arrays)
```

Let's create an instance and unpack it.

```{code-cell} ipython3
fontsize = 12
model = create_consumption_model()

# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
```

Let's see how long it takes to solve this model.

```{code-cell} ipython3
print("Starting VFI.")
start_time = time.time()
v_star, σ_star = value_iteration(model)
jax_elapsed = time.time() - start_time
print(f"VFI completed in {jax_elapsed} seconds.")
```

Let's do it once more to eliminate compile time:

```{code-cell} ipython3
start_time = time.time()
v_star, σ_star = value_iteration(model)
jax_elapsed = time.time() - start_time
print(f"VFI completed in {jax_elapsed} seconds.")
```
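
(An aside, not part of the original timing code: JAX dispatches work
asynchronously, so for fine-grained timing it is safer to block on the result
before reading the clock.)

```{code-cell} ipython3
start_time = time.time()
v_star, σ_star = value_iteration(model)
v_star.block_until_ready()   # wait until the computation has finished
jax_elapsed = time.time() - start_time
print(f"VFI completed in {jax_elapsed} seconds.")
```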

Here's a plot of the policy function.

```{code-cell} ipython3
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(w_grid, w_grid, "k--", label="45")
ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\\cdot, y_1)$")
ax.plot(w_grid, w_grid[σ_star[:, -1]], label="$\\sigma^*(\\cdot, y_N)$")
ax.legend(fontsize=fontsize)
plt.show()
```

## Comparison with NumPy

How much does JAX improve upon a more basic version using NumPy on the CPU?

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)

To answer this question, we simply change `jnp` to `np` and remove the
`jax.jit` and `jax.device_put` calls.

```{code-cell} ipython3
def create_consumption_model(R=1.01,        # Gross interest rate
                             β=0.98,        # Discount factor
                             γ=2,           # CRRA parameter
                             w_min=0.01,    # Min wealth
                             w_max=5.0,     # Max wealth
                             w_size=150,    # Wealth grid size
                             ρ=0.9, ν=0.1, y_size=100):  # Income parameters
    """
    A function that takes in parameters and returns parameters and grids
    for the optimal savings problem.
    """
    w_grid = np.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = np.exp(mc.state_values), mc.P
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```

```{code-cell} ipython3
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing

        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

    for all (w, y, w′).
    """
    # Unpack
    β, R, γ = constants
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays

    # Compute current consumption as array c[i, j, ip]
    w = np.reshape(w_grid, (w_size, 1, 1))    # w[i]   -> w[i, j, ip]
    y = np.reshape(y_grid, (1, y_size, 1))    # y[j]   -> y[i, j, ip]
    wp = np.reshape(w_grid, (1, 1, w_size))   # wp[ip] -> wp[i, j, ip]
    c = R * w + y - wp

    # Calculate continuation rewards at all combinations of (w, y, wp)
    v = np.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = np.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = np.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation, assigning -inf
    # to infeasible choices (those with c <= 0)
    return np.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -np.inf)
```

```{code-cell} ipython3
def T(v, constants, sizes, arrays):
    "The Bellman operator."
    return np.max(B(v, constants, sizes, arrays), axis=2)

def get_greedy(v, constants, sizes, arrays):
    "Computes a v-greedy policy, returned as a set of indices."
    return np.argmax(B(v, constants, sizes, arrays), axis=2)
```

Here's the solver.

```{code-cell} ipython3
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    vz = np.zeros(sizes)
    _T = lambda v: T(v, constants, sizes, arrays)
    v_star = successive_approx_jax(_T, vz, tolerance=tol)
    return v_star, get_greedy(v_star, constants, sizes, arrays)
```

Now we create an instance, unpack it, and test how long it takes to solve
the model.

```{code-cell} ipython3
fontsize = 12
model = create_consumption_model()

# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays

print("Starting VFI.")
start_time = time.time()
v_star, σ_star = value_iteration(model)
numpy_elapsed = time.time() - start_time
print(f"VFI completed in {numpy_elapsed} seconds.")
```
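
As a quick follow-up (not shown in the original text), we can compare the two
timings recorded above:

```{code-cell} ipython3
print(f"Relative speed gain of JAX over NumPy = {numpy_elapsed / jax_elapsed}")
```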

## Switching to vmap