Split optimal savings
jstac committed Mar 13, 2024
1 parent 97cbc37 commit c78e503
Showing 3 changed files with 386 additions and 37 deletions.
3 changes: 2 additions & 1 deletion lectures/_toc.yml
@@ -21,9 +21,10 @@ parts:
- caption: Dynamic Programming
numbered: true
chapters:
- file: opt_savings_1
- file: opt_savings_2
- file: short_path
- file: opt_invest
- file: opt_savings
# - file: inventory_ssd
- file: ifp_egm
- file: arellano
342 changes: 342 additions & 0 deletions lectures/opt_savings_1.md
@@ -0,0 +1,342 @@
---
jupytext:
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Optimal Savings

```{include} _admonition/gpu.md
```

In addition to what’s in Anaconda, this lecture will need the following libraries:

```{code-cell} ipython3
:tags: [hide-output]
!pip install quantecon
```

We will use the following imports:

```{code-cell} ipython3
import quantecon as qe
import numpy as np
import jax
import jax.numpy as jnp
from collections import namedtuple
import matplotlib.pyplot as plt
import time
```

Let's check the GPU we are running on.

```{code-cell} ipython3
!nvidia-smi
```

We'll use 64-bit floats to gain extra precision.

```{code-cell} ipython3
jax.config.update("jax_enable_x64", True)
```

## Overview

We consider an optimal savings problem with CRRA utility and budget constraint

$$ W_{t+1} + C_t \leq R W_t + Y_t $$

We assume that labor income $(Y_t)$ is a discretized AR(1) process.

The right-hand side of the Bellman equation is

$$ B((w, y), w', v) = u(Rw + y - w') + \beta \sum_{y'} v(w', y') Q(y, y'), $$

where

$$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
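
As a sanity check on the formula, here is a minimal sketch that evaluates $B$ at a single point by hand. The numbers for `w`, `y`, `wp`, the continuation values `v_next`, and the transition row `Q_row` are made up for illustration; only the default values of `R`, `β` and `γ` come from this lecture.

```python
import numpy as np

# Default parameter values used in this lecture
R, β, γ = 1.01, 0.98, 2

def u(c):
    "CRRA utility, valid for c > 0 and γ != 1."
    return c**(1 - γ) / (1 - γ)

# Hypothetical values, chosen only for illustration
w, y, wp = 1.0, 1.0, 0.5           # current wealth, income, next-period wealth
v_next = np.array([-1.0, -0.5])    # v(wp, y') for two income states
Q_row = np.array([0.5, 0.5])       # Q(y, y') for the current y

c = R * w + y - wp                 # consumption when the budget constraint binds
B_val = u(c) + β * np.sum(v_next * Q_row)
print(c, B_val)
```

The expectation is just a dot product between one row of $Q$ and the values $v(w', \cdot)$, which is the operation the vectorized code performs for every $(w, y, w')$ at once.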


The following function contains default parameters and returns tuples that
contain the key computational components of the model.


```{code-cell} ipython3
def create_consumption_model(R=1.01,                    # Gross interest rate
                             β=0.98,                    # Discount factor
                             γ=2,                       # CRRA parameter
                             w_min=0.01,                # Min wealth
                             w_max=5.0,                 # Max wealth
                             w_size=150,                # Grid size
                             ρ=0.9, ν=0.1, y_size=100): # Income parameters
    """
    A function that takes in parameters and returns parameters and grids
    for the optimal savings problem.
    """
    w_grid = jnp.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = jnp.exp(mc.state_values), mc.P
    β, R, γ = jax.device_put([β, R, γ])
    w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```

The function above also returns the sizes of the arrays. We pass these to the
functions below as static arguments, because JAX needs to know array shapes at
compile time.



## A solution using vectorization

We will start with a fully vectorized solution, in the sense that the right-hand
side of the Bellman equation is represented as a multi-dimensional array with
dimensions over all states and controls.

(Later we will examine an alternative method that uses `vmap`.)

Here's the right-hand side of the Bellman equation as a vectorized expression:

```{code-cell} ipython3
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing
        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
    for all (w, y, w′).
    """
    # Unpack
    β, R, γ = constants
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays
    # Compute current rewards r(w, y, wp) as array r[i, j, ip]
    w  = jnp.reshape(w_grid, (w_size, 1, 1))    # w[i]   ->  w[i, j, ip]
    y  = jnp.reshape(y_grid, (1, y_size, 1))    # y[j]   ->  y[i, j, ip]
    wp = jnp.reshape(w_grid, (1, 1, w_size))    # wp[ip] -> wp[i, j, ip]
    c = R * w + y - wp
    # Calculate continuation rewards at all combinations of (w, y, wp)
    v = jnp.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp
    # Compute the right-hand side of the Bellman equation,
    # assigning -inf to infeasible choices (nonpositive consumption)
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)

B = jax.jit(B, static_argnums=(2,))
```
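
To see what the reshapes are doing, the following sketch repeats the same broadcasting pattern in plain NumPy on tiny, made-up grids and checks the result against an explicit triple loop. The grid values, `v` and `Q` here are arbitrary placeholders rather than the model's, and the grids are chosen so that consumption is positive everywhere, which lets us skip the feasibility check.

```python
import numpy as np

β, R, γ = 0.98, 1.01, 2
w_grid = np.array([0.5, 1.0, 1.5])          # hypothetical wealth grid
y_grid = np.array([1.5, 2.0])               # hypothetical income grid
Q = np.array([[0.9, 0.1], [0.2, 0.8]])      # hypothetical transition matrix
v = np.arange(6.0).reshape(3, 2)            # arbitrary values v[ip, jp]
w_size, y_size = len(w_grid), len(y_grid)

# Broadcast version, following the same steps as B above
w = w_grid.reshape(w_size, 1, 1)
y = y_grid.reshape(1, y_size, 1)
wp = w_grid.reshape(1, 1, w_size)
c = R * w + y - wp                                    # shape (3, 2, 3)
EV = np.sum(v.reshape(1, 1, w_size, y_size)
            * Q.reshape(1, y_size, 1, y_size), axis=3)  # shape (1, 2, 3)
B_vec = c**(1 - γ) / (1 - γ) + β * EV                 # shape (3, 2, 3)

# Loop version: one (w, y, w') combination at a time
B_loop = np.empty((w_size, y_size, w_size))
for i in range(w_size):
    for j in range(y_size):
        for ip in range(w_size):
            cons = R * w_grid[i] + y_grid[j] - w_grid[ip]
            B_loop[i, j, ip] = cons**(1 - γ) / (1 - γ) + β * (v[ip] @ Q[j])

print(B_vec.shape, np.allclose(B_vec, B_loop))
```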

Readers familiar with MATLAB might be concerned that we are creating
high-dimensional arrays, leading to inefficiency.

Could they be avoided by more careful parallelization?

In fact this is not necessary: this function will be JIT-compiled by JAX, and
the compiler can fuse operations so that large intermediate arrays need not
always be stored in full.



### The Bellman operator

The Bellman operator $T$ can be implemented by

```{code-cell} ipython3
def T(v, constants, sizes, arrays):
    "The Bellman operator."
    return jnp.max(B(v, constants, sizes, arrays), axis=2)

T = jax.jit(T, static_argnums=(2,))
```

The next function computes a $v$-greedy policy given $v$ (i.e., the policy that
maximizes the right-hand side of the Bellman equation).

```{code-cell} ipython3
def get_greedy(v, constants, sizes, arrays):
    "Computes a v-greedy policy, returned as a set of indices."
    return jnp.argmax(B(v, constants, sizes, arrays), axis=2)

get_greedy = jax.jit(get_greedy, static_argnums=(2,))
```
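
As a small illustration of how the maximization works, the sketch below applies `max` and `argmax` along the last axis of a made-up array with the same `[i, j, ip]` layout. Infeasible choices carry `-inf`, so they are never selected as long as at least one choice is feasible.

```python
import numpy as np

# Hypothetical B values for 1 wealth state, 2 income states, 3 choices of w'
B_vals = np.array([[[-2.0, -1.5, -np.inf],
                    [-1.8, -1.2, -1.4]]])

Tv = np.max(B_vals, axis=2)        # value of the best choice at each (i, j)
σ = np.argmax(B_vals, axis=2)      # index of the best choice at each (i, j)
print(Tv)
print(σ)
```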



### Successive approximation
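
The solver in this section calls `successive_approx_jax`, which is not defined in this file. Here is a minimal sketch of what such a routine might look like, assuming it takes an operator, an initial condition and a `tolerance` keyword (as the call suggests) and iterates until successive iterates are close in the sup norm. The `max_iter` argument is an assumption. This version is a plain Python loop rather than a JIT-compiled one, so the same function can be used with either a JAX or a NumPy operator.

```python
import numpy as np

def successive_approx_jax(T,                  # Operator (callable)
                          x_0,                # Initial condition
                          tolerance=1e-6,     # Error tolerance
                          max_iter=10_000):   # Iteration bound (assumed)
    "Iterate x -> T(x) from x_0 until the sup-norm change falls below tolerance."
    x = x_0
    for _ in range(max_iter):
        x_new = T(x)
        error = float(abs(x_new - x).max())   # works for JAX and NumPy arrays
        x = x_new
        if error < tolerance:
            break
    return x

# Toy check: x -> 0.5 * x + 1 is a contraction with fixed point 2
x_star = successive_approx_jax(lambda x: 0.5 * x + 1, np.zeros(3), tolerance=1e-8)
print(x_star)
```

Converting the error to a Python float forces a synchronization with the device on every iteration, so a version built on `jax.lax.while_loop` would likely be faster on the GPU. The loop above trades that speed for simplicity.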

Now we define a solver that implements value function iteration (VFI).

```{code-cell} ipython3
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    vz = jnp.zeros(sizes)
    _T = lambda v: T(v, constants, sizes, arrays)
    v_star = successive_approx_jax(_T, vz, tolerance=tol)
    return v_star, get_greedy(v_star, constants, sizes, arrays)
```

Let's create an instance and unpack it.

```{code-cell} ipython3
fontsize = 12
model = create_consumption_model()
# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
```

Let's see how long it takes to solve this model.

```{code-cell} ipython3
print("Starting VFI.")
start_time = time.time()
v_star, σ_star = value_iteration(model)
jax_elapsed = time.time() - start_time
print(f"VFI completed in {jax_elapsed} seconds.")
```

Let's do it once more to eliminate compile time:


```{code-cell} ipython3
start_time = time.time()   # Reset the clock so the first run is not counted
v_star, σ_star = value_iteration(model)
jax_elapsed = time.time() - start_time
print(f"VFI completed in {jax_elapsed} seconds.")
```
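
One caveat when timing JAX code: JAX dispatches work asynchronously, so a call can return before the computation has finished. A common pattern, sketched below on a made-up function `f` (not part of the model), is to call `block_until_ready()` on the result before stopping the clock, and to time a second call separately so that compile time is excluded.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    "A stand-in computation, used only to illustrate timing."
    return jnp.sum(jnp.sin(x)**2 + jnp.cos(x)**2)

x = jnp.ones(1_000)

start = time.time()
f(x).block_until_ready()          # first call: includes compile time
first = time.time() - start

start = time.time()               # reset the clock before the second call
out = f(x).block_until_ready()    # second call: reuses the compiled code
second = time.time() - start

print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```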

Here's a plot of the policy function.


```{code-cell} ipython3
fig, ax = plt.subplots(figsize=(9, 5.2))
ax.plot(w_grid, w_grid, "k--", label="45")
# Index 0 is the lowest income state y_1 and index -1 the highest, y_N
ax.plot(w_grid, w_grid[σ_star[:, 0]], label=r"$\sigma^*(\cdot, y_1)$")
ax.plot(w_grid, w_grid[σ_star[:, -1]], label=r"$\sigma^*(\cdot, y_N)$")
ax.legend(fontsize=fontsize)
plt.show()
```
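
Since `σ_star` stores indices into `w_grid` rather than wealth levels, it can help to see the conversion spelled out. The sketch below uses a tiny made-up policy (not the solution computed above) to recover next-period wealth and the implied consumption $c = Rw + y - w'$.

```python
import numpy as np

R = 1.01
w_grid = np.array([0.5, 1.0, 1.5])     # hypothetical wealth grid
y_grid = np.array([1.5, 2.0])          # hypothetical income grid
σ = np.array([[0, 1],                  # hypothetical policy: σ[i, j] is the
              [1, 1],                  # index of w' chosen at (w_i, y_j)
              [1, 2]])

wp = w_grid[σ]                                   # next-period wealth, shape (3, 2)
c = R * w_grid[:, None] + y_grid[None, :] - wp   # implied consumption
print(wp)
print(c)
```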




## Comparison with NumPy

How much does JAX improve upon a more basic version using NumPy on the CPU?

(NumPy operations are similar to MATLAB operations, so this also serves as a
rough comparison with MATLAB.)

To answer this question, we simply change `jnp` to `np` and remove the `jax.jit`
requests.

```{code-cell} ipython3
def create_consumption_model(R=1.01,                    # Gross interest rate
                             β=0.98,                    # Discount factor
                             γ=2,                       # CRRA parameter
                             w_min=0.01,                # Min wealth
                             w_max=5.0,                 # Max wealth
                             w_size=150,                # Grid size
                             ρ=0.9, ν=0.1, y_size=100): # Income parameters
    """
    A function that takes in parameters and returns parameters and grids
    for the optimal savings problem.
    """
    w_grid = np.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = np.exp(mc.state_values), mc.P
    # No jax.device_put here: the arrays stay on the CPU as NumPy arrays
    sizes = w_size, y_size
    return (β, R, γ), sizes, (w_grid, y_grid, Q)
```

```{code-cell} ipython3
def B(v, constants, sizes, arrays):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing
        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
    for all (w, y, w′).
    """
    # Unpack
    β, R, γ = constants
    w_size, y_size = sizes
    w_grid, y_grid, Q = arrays
    # Compute current rewards r(w, y, wp) as array r[i, j, ip]
    w  = np.reshape(w_grid, (w_size, 1, 1))    # w[i]   ->  w[i, j, ip]
    y  = np.reshape(y_grid, (1, y_size, 1))    # y[j]   ->  y[i, j, ip]
    wp = np.reshape(w_grid, (1, 1, w_size))    # wp[ip] -> wp[i, j, ip]
    c = R * w + y - wp
    # Calculate continuation rewards at all combinations of (w, y, wp)
    v = np.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = np.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = np.sum(v * Q, axis=3)                 # sum over last index jp
    # Compute the right-hand side of the Bellman equation
    return np.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -np.inf)
```

```{code-cell} ipython3
def T(v, constants, sizes, arrays):
    "The Bellman operator."
    return np.max(B(v, constants, sizes, arrays), axis=2)

def get_greedy(v, constants, sizes, arrays):
    "Computes a v-greedy policy, returned as a set of indices."
    return np.argmax(B(v, constants, sizes, arrays), axis=2)
```

Here's the solver.

```{code-cell} ipython3
def value_iteration(model, tol=1e-5):
    constants, sizes, arrays = model
    vz = np.zeros(sizes)
    _T = lambda v: T(v, constants, sizes, arrays)
    v_star = successive_approx_jax(_T, vz, tolerance=tol)
    return v_star, get_greedy(v_star, constants, sizes, arrays)
```

Now we create an instance, unpack it, and test how long it takes to solve the
model.

```{code-cell} ipython3
fontsize = 12
model = create_consumption_model()
# Unpack
constants, sizes, arrays = model
β, R, γ = constants
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
print("Starting VFI.")
start_time = time.time()
v_star, σ_star = value_iteration(model)
numpy_elapsed = time.time() - start_time
print(f"VFI completed in {numpy_elapsed} seconds.")
```


## Switching to vmap

