Use vmap throughout opt savings 2 (#155)
* misc

* misc

* misc
jstac authored Mar 15, 2024
1 parent ce4ef38 commit ef43c79
Showing 2 changed files with 112 additions and 79 deletions.
6 changes: 5 additions & 1 deletion lectures/opt_savings_1.md
@@ -120,11 +120,15 @@ def create_consumption_model(R=1.01, # Gross interest rate
A function that takes in parameters and returns parameters and grids
for the optimal savings problem.
"""
# Build grids and transition probabilities
w_grid = np.linspace(w_min, w_max, w_size)
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
y_grid, Q = np.exp(mc.state_values), mc.P
# Pack and return
params = β, R, γ
sizes = w_size, y_size
    arrays = w_grid, y_grid, Q
    return params, sizes, arrays
```

(The function returns sizes of arrays because we use them later to help compile our functions.)
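
As an aside, here is a toy sketch (not part of either lecture file) of why grid sizes are passed around separately: marking them as static arguments lets `jax.jit` see concrete shapes at trace time. The function `make_grid` below is made up purely for illustration.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

def make_grid(bounds, sizes):
    n, m = sizes                  # static: plain Python ints at trace time
    lo, hi = bounds               # traced scalars
    return jnp.linspace(lo, hi, n * m).reshape(n, m)

make_grid = jax.jit(make_grid, static_argnums=(1,))
print(make_grid((0.0, 1.0), (3, 4)).shape)   # (3, 4); new sizes trigger a recompile
```
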
185 changes: 107 additions & 78 deletions lectures/opt_savings_2.md
@@ -100,121 +100,155 @@ def create_consumption_model(R=1.01, # Gross interest rate
A function that takes in parameters and returns parameters and grids
for the optimal savings problem.
"""
# Build grids and transition probabilities
w_grid = jnp.linspace(w_min, w_max, w_size)
mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = jnp.exp(mc.state_values), mc.P
# Pack and return
params = β, R, γ
sizes = w_size, y_size
    arrays = w_grid, y_grid, jnp.array(Q)
    return params, sizes, arrays
```
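
As a quick usage sketch (not part of the committed file), we can build the model with the defaults above and unpack the pieces; the other sketches added below reuse these names.

```{code-cell} ipython3
# Build the model and unpack parameters, grid sizes, and arrays
params, sizes, arrays = create_consumption_model()
β, R, γ = params
w_size, y_size = sizes
w_grid, y_grid, Q = arrays
print(w_grid.shape, y_grid.shape, Q.shape)   # (w_size,), (y_size,), (y_size, y_size)
```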

Here's the right-hand side of the Bellman equation:

```{code-cell} ipython3
def _B(v, params, arrays, i, j, ip):
    """
    The right-hand side of the Bellman equation before maximization, which takes
    the form

        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

    for all (w, y, w′).

    The indices are (i, j, ip) -> (w, y, w′).
    """
    # Unpack
    β, R, γ = params
    w_grid, y_grid, Q = arrays
    # Evaluate consumption at the single index triple (i, j, ip)
    w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
c = R * w + y - wp
EV = jnp.sum(v[ip, :] * Q[j, :])
return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
```

Now we successively apply `vmap` to vectorize $B$ by simulating nested loops.

```{code-cell} ipython3
B_1 = jax.vmap(_B, in_axes=(None, None, None, None, None, 0))
B_2 = jax.vmap(B_1, in_axes=(None, None, None, None, 0, None))
B_vmap = jax.vmap(B_2, in_axes=(None, None, None, 0, None, None))
```
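
To see what the nested `vmap` calls are doing, here is a toy illustration, separate from the lecture code and using a made-up scalar function `f`: each `vmap` layer plays the role of one loop.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

def f(i, j):
    return 10 * i + j

f_j = jax.vmap(f, in_axes=(None, 0))      # inner loop: map over j with i fixed
f_ij = jax.vmap(f_j, in_axes=(0, None))   # outer loop: map over i

i_indices, j_indices = jnp.arange(3), jnp.arange(4)
out = f_ij(i_indices, j_indices)          # out[i, j] = f(i, j), shape (3, 4)
print(out.shape)
```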

Here's a fully vectorized version of $B$.

```{code-cell} ipython3
def B(v, params, sizes, arrays):
w_size, y_size = sizes
w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
return B_vmap(v, params, arrays, w_indices, y_indices, w_indices)
B = jax.jit(B, static_argnums=(2,))
```
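
As a quick shape check (a sketch that assumes the model objects unpacked in the sketch after `create_consumption_model`), `B` returns one value for every `(w, y, w′)` triple.

```{code-cell} ipython3
# With v of shape (w_size, y_size), B produces a (w_size, y_size, w_size) array
v = jnp.zeros((w_size, y_size))
print(B(v, params, sizes, arrays).shape)
```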

## Operators


Here's the Bellman operator $T$

```{code-cell} ipython3
def T(v, params, sizes, arrays):
"The Bellman operator."
return jnp.max(B(v, params, sizes, arrays), axis=-1)
T = jax.jit(T, static_argnums=(2,))
```

The next function computes a $v$-greedy policy given $v$

```{code-cell} ipython3
def get_greedy(v, params, sizes, arrays):
"Computes a v-greedy policy, returned as a set of indices."
return jnp.argmax(B(v, params, sizes, arrays), axis=-1)
get_greedy = jax.jit(get_greedy, static_argnums=(2,))
```
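
Continuing the sketch (and assuming `v`, `params`, `sizes` and `arrays` from above), one step of value function iteration followed by extracting a greedy policy looks like this.

```{code-cell} ipython3
# Apply the Bellman operator once, then read off a greedy policy
v_new = T(v, params, sizes, arrays)
σ = get_greedy(v_new, params, sizes, arrays)
print(v_new.shape, σ.shape)   # both (w_size, y_size)
```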

We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
which is defined as the vector

$$
r_\sigma(w, y) := r(w, y, \sigma(w, y))
$$

```{code-cell} ipython3
def _compute_r_σ(σ, params, arrays, i, j):
    """
    With indices (i, j) -> (w, y) and wp = σ[i, j], compute

        r_σ[i, j] = u(Rw + y - wp)

    which gives current rewards under policy σ.
    """
    # Unpack model
    β, R, γ = params
    w_grid, y_grid, Q = arrays
    # Compute r_σ[i, j]
    w, y, wp = w_grid[i], y_grid[j], w_grid[σ[i, j]]
c = R * w + y - wp
r_σ = c**(1-γ)/(1-γ)
return r_σ
```

Now we successively apply `vmap` to simulate nested loops.

```{code-cell} ipython3
r_1 = jax.vmap(_compute_r_σ, in_axes=(None, None, None, None, 0))
r_σ_vmap = jax.vmap(r_1, in_axes=(None, None, None, 0, None))
```

Here's a fully vectorized version of $r_\sigma$.

```{code-cell} ipython3
def compute_r_σ(σ, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return r_σ_vmap(σ, params, arrays, w_indices, y_indices)

compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
```
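
As another small sketch (again assuming the model objects from above), we can evaluate rewards under the trivial policy that always selects the first savings grid point.

```{code-cell} ipython3
# σ[i, j] = 0 means next-period wealth is always w_grid[0]
σ_zero = jnp.zeros((w_size, y_size), dtype=jnp.int32)
print(compute_r_σ(σ_zero, params, sizes, arrays).shape)   # (w_size, y_size)
```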

Now we define the policy operator $T_\sigma$ going through similar steps

```{code-cell} ipython3
def _T_σ(v, σ, params, arrays, i, j):
    "The σ-policy operator."
    # Unpack model
    β, R, γ = params
    w_grid, y_grid, Q = arrays
    r_σ = _compute_r_σ(σ, params, arrays, i, j)
    # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[j, jp]
    EV = jnp.sum(v[σ[i, j], :] * Q[j, :])
    return r_σ + β * EV

T_1 = jax.vmap(_T_σ, in_axes=(None, None, None, None, None, 0))
T_σ_vmap = jax.vmap(T_1, in_axes=(None, None, None, None, 0, None))

def T_σ(v, σ, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return T_σ_vmap(v, σ, params, arrays, w_indices, y_indices)

T_σ = jax.jit(T_σ, static_argnums=(3,))
```
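
Here is a consistency sketch, assuming the objects defined in the sketches above: applying `T_σ` at a `v`-greedy policy should reproduce the Bellman operator applied to `v`.

```{code-cell} ipython3
# T v = T_σ v when σ is v-greedy, up to floating point rounding
σ_greedy = get_greedy(v, params, sizes, arrays)
print(jnp.allclose(T(v, params, sizes, arrays),
                   T_σ(v, σ_greedy, params, sizes, arrays)))   # expected: True
```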


The function below computes the value $v_\sigma$ of following policy $\sigma$.

This lifetime value is a function $v_\sigma$ that satisfies
@@ -248,29 +282,28 @@

JAX allows us to solve linear systems defined in terms of operators; the first
step is to define the function $L_{\sigma}$.

```{code-cell} ipython3
def _L_σ(v, σ, params, arrays, i, j):
    """
    Here we set up the linear map v -> L_σ v, where

        (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)

    """
    # Unpack
    β, R, γ = params
    w_grid, y_grid, Q = arrays
    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
    return v[i, j] - β * jnp.sum(v[σ[i, j], :] * Q[j, :])

L_1 = jax.vmap(_L_σ, in_axes=(None, None, None, None, None, 0))
L_σ_vmap = jax.vmap(L_1, in_axes=(None, None, None, None, 0, None))

def L_σ(v, σ, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return L_σ_vmap(v, σ, params, arrays, w_indices, y_indices)

L_σ = jax.jit(L_σ, static_argnums=(3,))
```

Now we can define a function to compute $v_{\sigma}$
@@ -290,20 +323,16 @@ def get_value(σ, params, sizes, arrays):
partial_L_σ = lambda v: L_σ(v, σ, params, sizes, arrays)
return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]

get_value = jax.jit(get_value, static_argnums=(2,))
```
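
To make the operator-based solve concrete, here is a toy sketch, separate from the lecture and using a made-up 2×2 matrix `M`: `bicgstab` accepts a function in place of a matrix, so the system matrix is never formed explicitly.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

M = jnp.array([[0.5, 0.5],
               [0.2, 0.8]])
b = jnp.array([1.0, 2.0])
A = lambda v: v - 0.9 * (M @ v)    # the operator v ↦ (I - 0.9 M) v

x, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
print(jnp.allclose(A(x), b, atol=1e-4))   # expected: True
```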



## Iteration


We use successive approximation for VFI.
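
Here is a minimal sketch of such a loop; the lecture's own implementation sits in the collapsed part of this diff, so the function below is only illustrative.

```{code-cell} ipython3
def successive_approx(update, v_init, tol=1e-6, max_iter=10_000):
    "Iterate v ← update(v) until the sup-norm change falls below tol."
    v = v_init
    for _ in range(max_iter):
        v_new = update(v)
        if jnp.max(jnp.abs(v_new - v)) < tol:
            break
        v = v_new
    return v_new

# Usage sketch:
# v_star = successive_approx(lambda v: T(v, params, sizes, arrays),
#                            jnp.zeros((w_size, y_size)))
# σ_star = get_greedy(v_star, params, sizes, arrays)
```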

