From 4beb13a338007b24c20beb6bf7ba235d204ea560 Mon Sep 17 00:00:00 2001
From: kp992 <145801876+kp992@users.noreply.github.com>
Date: Sat, 2 Mar 2024 01:41:03 +0530
Subject: [PATCH] Use JAX in successive_approx (#134)

---
 .../lecture_specific/successive_approx.py | 38 +++++++++----------
 lectures/_static/lecture_specific/vfi.py  |  3 +-
 lectures/opt_invest.md                    | 27 ++++---------
 lectures/opt_savings.md                   | 27 ++++++-------
 4 files changed, 40 insertions(+), 55 deletions(-)

diff --git a/lectures/_static/lecture_specific/successive_approx.py b/lectures/_static/lecture_specific/successive_approx.py
index 18ddbf4d..5d64d82f 100644
--- a/lectures/_static/lecture_specific/successive_approx.py
+++ b/lectures/_static/lecture_specific/successive_approx.py
@@ -1,21 +1,21 @@
-def successive_approx(T,                     # Operator (callable)
-                      x_0,                   # Initial condition
-                      tolerance=1e-6,        # Error tolerance
-                      max_iter=10_000,       # Max iteration bound
-                      print_step=25,         # Print at multiples
-                      verbose=False):
-    x = x_0
-    error = tolerance + 1
-    k = 1
-    while error > tolerance and k <= max_iter:
-        x_new = T(x)
+def successive_approx_jax(x_0,                   # Initial condition
+                          constants,
+                          sizes,
+                          arrays,
+                          tolerance=1e-6,        # Error tolerance
+                          max_iter=10_000):      # Max iteration bound
+
+    def body_fun(k_x_err):
+        k, x, error = k_x_err
+        x_new = T(x, constants, sizes, arrays)
         error = jnp.max(jnp.abs(x_new - x))
-        if verbose and k % print_step == 0:
-            print(f"Completed iteration {k} with error {error}.")
-        x = x_new
-        k += 1
-    if error > tolerance:
-        print(f"Warning: Iteration hit upper bound {max_iter}.")
-    elif verbose:
-        print(f"Terminated successfully in {k} iterations.")
+        return k + 1, x_new, error
+
+    def cond_fun(k_x_err):
+        k, x, error = k_x_err
+        return jnp.logical_and(error > tolerance, k < max_iter)
+
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
     return x
+
+successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(2,))
diff --git a/lectures/_static/lecture_specific/vfi.py b/lectures/_static/lecture_specific/vfi.py
index d4f1f7b7..e96a048f 100644
--- a/lectures/_static/lecture_specific/vfi.py
+++ b/lectures/_static/lecture_specific/vfi.py
@@ -2,8 +2,7 @@ def value_iteration(model, tol=1e-5):
 
     constants, sizes, arrays = model
-    _T = lambda v: T(v, constants, sizes, arrays)
     vz = jnp.zeros(sizes)
 
-    v_star = successive_approx(_T, vz, tolerance=tol)
+    v_star = successive_approx_jax(vz, constants, sizes, arrays, tolerance=tol)
     return get_greedy(v_star, constants, sizes, arrays)
diff --git a/lectures/opt_invest.md b/lectures/opt_invest.md
index 20f1e6fa..ca01135c 100644
--- a/lectures/opt_invest.md
+++ b/lectures/opt_invest.md
@@ -4,14 +4,13 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.5
+    jupytext_version: 1.16.1
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
   name: python3
 ---
-
 # Optimal Investment
 
 ```{include} _admonition/gpu.md
 ```
@@ -76,14 +75,6 @@ We will use 64 bit floats with JAX in order to increase the precision.
 jax.config.update("jax_enable_x64", True)
 ```
-
-We need the following successive approximation function.
-
-```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
-```
-
-
 Let's define a function to create an investment model using the given parameters.
 ```{code-cell} ipython3
@@ -113,7 +104,6 @@ def create_investment_model(
     return constants, sizes, arrays
 ```
-
 Let's re-write the vectorized version of the right-hand side of the Bellman equation (before maximization), which is a 3D array representing
@@ -183,7 +173,6 @@ def compute_r_σ(σ, constants, sizes, arrays):
 compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
 ```
-
 Define the Bellman operator.
 
 ```{code-cell} ipython3
@@ -194,7 +183,6 @@ def T(v, constants, sizes, arrays):
 T = jax.jit(T, static_argnums=(2,))
 ```
-
 The following function computes a v-greedy policy.
 
 ```{code-cell} ipython3
@@ -205,7 +193,6 @@ def get_greedy(v, constants, sizes, arrays):
 get_greedy = jax.jit(get_greedy, static_argnums=(2,))
 ```
-
 Define the $\sigma$-policy operator.
 
 ```{code-cell} ipython3
@@ -236,7 +223,6 @@ def T_σ(v, σ, constants, sizes, arrays):
 T_σ = jax.jit(T_σ, static_argnums=(3,))
 ```
-
 Next, we want to compute the lifetime value of following policy $\sigma$.
 
 This lifetime value is a function $v_\sigma$ that satisfies
@@ -285,8 +271,7 @@ def L_σ(v, σ, constants, sizes, arrays):
 L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
-Now we can define a function to compute $v_{\sigma}$
-
+Now we can define a function to compute $v_{\sigma}$
 
 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -306,6 +291,11 @@ def get_value(σ, constants, sizes, arrays):
 get_value = jax.jit(get_value, static_argnums=(2,))
 ```
+We use successive approximation for VFI.
+
+```{code-cell} ipython3
+:load: _static/lecture_specific/successive_approx.py
+```
 
 Finally, we introduce the solvers that implement VFI, HPI and OPI.
@@ -355,7 +345,6 @@ print(out)
 print(f"OPI completed in {elapsed} seconds.")
 ```
-
 Here's the plot of the Howard policy, as a function of $y$ at the highest and lowest values of $z$.
 
 ```{code-cell} ipython3
@@ -377,7 +366,6 @@ ax.legend(fontsize=12)
 plt.show()
 ```
-
 Let's plot the time taken by each of the solvers and compare them.
 
 ```{code-cell} ipython3
@@ -403,6 +391,7 @@ print(f"VFI completed in {vfi_time} seconds.")
 ```{code-cell} ipython3
 :tags: [hide-output]
+
 opi_times = []
 for m in m_vals:
     print(f"Running optimistic policy iteration with m={m}.")
diff --git a/lectures/opt_savings.md b/lectures/opt_savings.md
index de1fd9f2..7468fe93 100644
--- a/lectures/opt_savings.md
+++ b/lectures/opt_savings.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.5
+    jupytext_version: 1.16.1
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -65,20 +65,13 @@ where
 
 $$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$
 
-+++
-
-We use successive approximation for VFI.
-
-```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
-```
 
 ## Model primitives
 
 First we define a model that stores parameters and grids.
 
 ```{code-cell} ipython3
-def create_consumption_model(R=1.01,            # Gross interest rate
+def create_consumption_model(R=1.01,             # Gross interest rate
                              β=0.98,             # Discount factor
                              γ=2,                # CRRA parameter
                              w_min=0.01,         # Min wealth
@@ -140,8 +133,6 @@ which is defined as the vector
 
 $$ r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$
 
-
-
 ```{code-cell} ipython3
 def compute_r_σ(σ, constants, sizes, arrays):
     """
@@ -187,9 +178,9 @@ def T_σ(v, σ, constants, sizes, arrays):
     Q = jnp.reshape(Q, (1, y_size, y_size))
 
     # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
-    Ev = jnp.sum(V * Q, axis=2)
+    EV = jnp.sum(V * Q, axis=2)
 
-    return r_σ + β * Ev
+    return r_σ + β * EV
 ```
 
 and the Bellman operator $T$
@@ -260,7 +251,7 @@ def L_σ(v, σ, constants, sizes, arrays):
     return v - β * jnp.sum(V * Q, axis=2)
 ```
 
-Now we can define a function to compute $v_{\sigma}$
+Now we can define a function to compute $v_{\sigma}$
 
 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -291,6 +282,12 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
 L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
 
+We use successive approximation for VFI.
+
+```{code-cell} ipython3
+:load: _static/lecture_specific/successive_approx.py
+```
+
 ## Solvers
 
 Now we define the solvers, which implement VFI, HPI and OPI.
@@ -353,7 +350,7 @@ print("Starting VFI.")
 start_time = time.time()
 out = value_iteration(model)
 elapsed = time.time() - start_time
-print(f"VFI(jax not in succ) completed in {elapsed} seconds.")
+print(f"VFI completed in {elapsed} seconds.")
 ```
 
 ```{code-cell} ipython3