From ef43c79b606725b8636af4e670a8b95b7c94f8eb Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sat, 16 Mar 2024 08:01:02 +1100 Subject: [PATCH] Use vmap throughout opt savings 2 (#155) * misc * misc * misc --- lectures/opt_savings_1.md | 6 +- lectures/opt_savings_2.md | 185 ++++++++++++++++++++++---------------- 2 files changed, 112 insertions(+), 79 deletions(-) diff --git a/lectures/opt_savings_1.md b/lectures/opt_savings_1.md index 7856d397..401100de 100644 --- a/lectures/opt_savings_1.md +++ b/lectures/opt_savings_1.md @@ -120,11 +120,15 @@ def create_consumption_model(R=1.01, # Gross interest rate A function that takes in parameters and returns parameters and grids for the optimal savings problem. """ + # Build grids and transition probabilities w_grid = np.linspace(w_min, w_max, w_size) mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν) y_grid, Q = np.exp(mc.state_values), mc.P + # Pack and return + params = β, R, γ sizes = w_size, y_size - return (β, R, γ), sizes, (w_grid, y_grid, Q) + arrays = w_grid, y_grid, Q + return params, sizes, arrays ``` (The function returns sizes of arrays because we use them later to help diff --git a/lectures/opt_savings_2.md b/lectures/opt_savings_2.md index f45b163e..0cd7ed61 100644 --- a/lectures/opt_savings_2.md +++ b/lectures/opt_savings_2.md @@ -100,121 +100,155 @@ def create_consumption_model(R=1.01, # Gross interest rate A function that takes in parameters and returns parameters and grids for the optimal savings problem. """ + # Build grids and transition probabilities w_grid = jnp.linspace(w_min, w_max, w_size) mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν) - y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P) + y_grid, Q = jnp.exp(mc.state_values), mc.P + # Pack and return + params = β, R, γ sizes = w_size, y_size - return (β, R, γ), sizes, (w_grid, y_grid, Q) + arrays = w_grid, y_grid, jnp.array(Q) + return params, sizes, arrays ``` Here's the right hand side of the Bellman equation: ```{code-cell} ipython3 -def B(v, params, sizes, arrays): +def _B(v, params, arrays, i, j, ip): """ - A vectorized version of the right-hand side of the Bellman equation - (before maximization), which is a 3D array representing + The right-hand side of the Bellman equation before maximization, which takes + the form B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′) - for all (w, y, w′). + The indices are (i, j, ip) -> (w, y, w′). """ - - # Unpack β, R, γ = params - w_size, y_size = sizes w_grid, y_grid, Q = arrays - - # Compute current rewards r(w, y, wp) as array r[i, j, ip] - w = jnp.reshape(w_grid, (w_size, 1, 1)) # w[i] -> w[i, j, ip] - y = jnp.reshape(y_grid, (1, y_size, 1)) # z[j] -> z[i, j, ip] - wp = jnp.reshape(w_grid, (1, 1, w_size)) # wp[ip] -> wp[i, j, ip] + w, y, wp = w_grid[i], y_grid[j], w_grid[ip] c = R * w + y - wp + EV = jnp.sum(v[ip, :] * Q[j, :]) + return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf) +``` - # Calculate continuation rewards at all combinations of (w, y, wp) - v = jnp.reshape(v, (1, 1, w_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp] - Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp] -> Q[i, j, ip, jp] - EV = jnp.sum(v * Q, axis=3) # sum over last index jp +Now we successively apply `vmap` to vectorize $B$ by simulating nested loops. - # Compute the right-hand side of the Bellman equation - return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf) +```{code-cell} ipython3 +B_1 = jax.vmap(_B, in_axes=(None, None, None, None, None, 0)) +B_2 = jax.vmap(B_1, in_axes=(None, None, None, None, 0, None)) +B_vmap = jax.vmap(B_2, in_axes=(None, None, None, 0, None, None)) +``` + +Here's a fully vectorized version of $B$. + +```{code-cell} ipython3 +def B(v, params, sizes, arrays): + w_size, y_size = sizes + w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) + return B_vmap(v, params, arrays, w_indices, y_indices, w_indices) + +B = jax.jit(B, static_argnums=(2,)) ``` ## Operators + +Here's the Bellman operator $T$ + +```{code-cell} ipython3 +def T(v, params, sizes, arrays): + "The Bellman operator." + return jnp.max(B(v, params, sizes, arrays), axis=-1) + +T = jax.jit(T, static_argnums=(2,)) +``` + +The next function computes a $v$-greedy policy given $v$ + +```{code-cell} ipython3 +def get_greedy(v, params, sizes, arrays): + "Computes a v-greedy policy, returned as a set of indices." + return jnp.argmax(B(v, params, sizes, arrays), axis=-1) + +get_greedy = jax.jit(get_greedy, static_argnums=(2,)) + +``` + We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$, which is defined as the vector $$ -r_\sigma(w, y) := r(w, y, \sigma(w, y)) + r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$ ```{code-cell} ipython3 -def compute_r_σ(σ, params, sizes, arrays): +def _compute_r_σ(σ, params, arrays, i, j): """ - Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current - rewards given policy σ. + With indices (i, j) -> (w, y) and wp = σ[i, j], compute + + r_σ[i, j] = u(Rw + y - wp) + + which gives current rewards under policy σ. """ # Unpack model β, R, γ = params - w_size, y_size = sizes w_grid, y_grid, Q = arrays - # Compute r_σ[i, j] - w = jnp.reshape(w_grid, (w_size, 1)) - y = jnp.reshape(y_grid, (1, y_size)) - wp = w_grid[σ] + w, y, wp = w_grid[i], y_grid[j], w_grid[σ[i, j]] c = R * w + y - wp r_σ = c**(1-γ)/(1-γ) return r_σ ``` -Now we define the policy operator $T_\sigma$ +Now we successively apply `vmap` to simulate nested loops. ```{code-cell} ipython3 -def T_σ(v, σ, params, sizes, arrays): - "The σ-policy operator." +r_1 = jax.vmap(_compute_r_σ, in_axes=(None, None, None, None, 0)) +r_σ_vmap = jax.vmap(r_1, in_axes=(None, None, None, 0, None)) +``` - # Unpack model - β, R, γ = params +Here's a fully vectorized version of $r_\sigma$. + +```{code-cell} ipython3 +def compute_r_σ(σ, params, sizes, arrays): w_size, y_size = sizes - w_grid, y_grid, Q = arrays + w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) + return r_σ_vmap(σ, params, arrays, w_indices, y_indices) - r_σ = compute_r_σ(σ, params, sizes, arrays) +compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,)) +``` + +Now we define the policy operator $T_\sigma$ going through similar steps - # Compute the array v[σ[i, j], jp] - yp_idx = jnp.arange(y_size) - yp_idx = jnp.reshape(yp_idx, (1, 1, y_size)) - σ = jnp.reshape(σ, (w_size, y_size, 1)) - V = v[σ, yp_idx] +```{code-cell} ipython3 +def _T_σ(v, σ, params, arrays, i, j): + "The σ-policy operator." - # Convert Q[j, jp] to Q[i, j, jp] - Q = jnp.reshape(Q, (1, y_size, y_size)) + # Unpack model + β, R, γ = params + w_grid, y_grid, Q = arrays + r_σ = _compute_r_σ(σ, params, arrays, i, j) # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp] - EV = jnp.sum(V * Q, axis=2) + EV = jnp.sum(v[σ[i, j], :] * Q[j, :]) return r_σ + β * EV -``` -and the Bellman operator $T$ -```{code-cell} ipython3 -def T(v, params, sizes, arrays): - "The Bellman operator." - return jnp.max(B(v, params, sizes, arrays), axis=2) -``` +T_1 = jax.vmap(_T_σ, in_axes=(None, None, None, None, None, 0)) +T_σ_vmap = jax.vmap(T_1, in_axes=(None, None, None, None, 0, None)) -The next function computes a $v$-greedy policy given $v$ +def T_σ(v, σ, params, sizes, arrays): + w_size, y_size = sizes + w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) + return T_σ_vmap(v, σ, params, arrays, w_indices, y_indices) -```{code-cell} ipython3 -def get_greedy(v, params, sizes, arrays): - "Computes a v-greedy policy, returned as a set of indices." - return jnp.argmax(B(v, params, sizes, arrays), axis=2) +T_σ = jax.jit(T_σ, static_argnums=(3,)) ``` + The function below computes the value $v_\sigma$ of following policy $\sigma$. This lifetime value is a function $v_\sigma$ that satisfies @@ -248,29 +282,28 @@ JAX allows us to solve linear systems defined in terms of operators; the first step is to define the function $L_{\sigma}$. ```{code-cell} ipython3 -def L_σ(v, σ, params, sizes, arrays): +def _L_σ(v, σ, params, arrays, i, j): """ Here we set up the linear map v -> L_σ v, where (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′) """ - + # Unpack β, R, γ = params - w_size, y_size = sizes w_grid, y_grid, Q = arrays + # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp] + return v[i, j] - β * jnp.sum(v[σ[i, j], :] * Q[j, :]) - # Set up the array v[σ[i, j], jp] - zp_idx = jnp.arange(y_size) - zp_idx = jnp.reshape(zp_idx, (1, 1, y_size)) - σ = jnp.reshape(σ, (w_size, y_size, 1)) - V = v[σ, zp_idx] +L_1 = jax.vmap(_L_σ, in_axes=(None, None, None, None, None, 0)) +L_σ_vmap = jax.vmap(L_1, in_axes=(None, None, None, None, 0, None)) - # Expand Q[j, jp] to Q[i, j, jp] - Q = jnp.reshape(Q, (1, y_size, y_size)) +def L_σ(v, σ, params, sizes, arrays): + w_size, y_size = sizes + w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) + return L_σ_vmap(v, σ, params, arrays, w_indices, y_indices) - # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp] - return v - β * jnp.sum(V * Q, axis=2) +L_σ = jax.jit(L_σ, static_argnums=(3,)) ``` Now we can define a function to compute $v_{\sigma}$ @@ -290,20 +323,16 @@ def get_value(σ, params, sizes, arrays): partial_L_σ = lambda v: L_σ(v, σ, params, sizes, arrays) return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0] -``` - -## JIT compiled versions -```{code-cell} ipython3 -B = jax.jit(B, static_argnums=(2,)) -compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,)) -T = jax.jit(T, static_argnums=(2,)) -get_greedy = jax.jit(get_greedy, static_argnums=(2,)) get_value = jax.jit(get_value, static_argnums=(2,)) -T_σ = jax.jit(T_σ, static_argnums=(3,)) -L_σ = jax.jit(L_σ, static_argnums=(3,)) + ``` + + +## Iteration + + We use successive approximation for VFI. ```{code-cell} ipython3