Skip to content

Commit

Permalink
Use JAX in successive_approx (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
kp992 authored Mar 1, 2024
1 parent d4e46d3 commit 4beb13a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 55 deletions.
38 changes: 19 additions & 19 deletions lectures/_static/lecture_specific/successive_approx.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
def successive_approx(T, # Operator (callable)
x_0, # Initial condition
tolerance=1e-6, # Error tolerance
max_iter=10_000, # Max iteration bound
print_step=25, # Print at multiples
verbose=False):
x = x_0
error = tolerance + 1
k = 1
while error > tolerance and k <= max_iter:
x_new = T(x)
def successive_approx_jax(x_0, # Initial condition
constants,
sizes,
arrays,
tolerance=1e-6, # Error tolerance
max_iter=10_000): # Max iteration bound

def body_fun(k_x_err):
k, x, error = k_x_err
x_new = T(x, constants, sizes, arrays)
error = jnp.max(jnp.abs(x_new - x))
if verbose and k % print_step == 0:
print(f"Completed iteration {k} with error {error}.")
x = x_new
k += 1
if error > tolerance:
print(f"Warning: Iteration hit upper bound {max_iter}.")
elif verbose:
print(f"Terminated successfully in {k} iterations.")
return k + 1, x_new, error

def cond_fun(k_x_err):
k, x, error = k_x_err
return jnp.logical_and(error > tolerance, k < max_iter)

k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
return x

successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(2,))
3 changes: 1 addition & 2 deletions lectures/_static/lecture_specific/vfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

def value_iteration(model, tol=1e-5):
constants, sizes, arrays = model
_T = lambda v: T(v, constants, sizes, arrays)
vz = jnp.zeros(sizes)

v_star = successive_approx(_T, vz, tolerance=tol)
v_star = successive_approx_jax(vz, constants, sizes, arrays, tolerance=tol)
return get_greedy(v_star, constants, sizes, arrays)
27 changes: 8 additions & 19 deletions lectures/opt_invest.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.5
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---


# Optimal Investment

```{include} _admonition/gpu.md
Expand Down Expand Up @@ -76,14 +75,6 @@ We will use 64 bit floats with JAX in order to increase the precision.
jax.config.update("jax_enable_x64", True)
```


We need the following successive approximation function.

```{code-cell} ipython3
:load: _static/lecture_specific/successive_approx.py
```


Let's define a function to create an investment model using the given parameters.

```{code-cell} ipython3
Expand Down Expand Up @@ -113,7 +104,6 @@ def create_investment_model(
return constants, sizes, arrays
```


Let's re-write the vectorized version of the right-hand side of the
Bellman equation (before maximization), which is a 3D array representing

Expand Down Expand Up @@ -183,7 +173,6 @@ def compute_r_σ(σ, constants, sizes, arrays):
compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
```


Define the Bellman operator.

```{code-cell} ipython3
Expand All @@ -194,7 +183,6 @@ def T(v, constants, sizes, arrays):
T = jax.jit(T, static_argnums=(2,))
```


The following function computes a v-greedy policy.

```{code-cell} ipython3
Expand All @@ -205,7 +193,6 @@ def get_greedy(v, constants, sizes, arrays):
get_greedy = jax.jit(get_greedy, static_argnums=(2,))
```


Define the $\sigma$-policy operator.

```{code-cell} ipython3
Expand Down Expand Up @@ -236,7 +223,6 @@ def T_σ(v, σ, constants, sizes, arrays):
T_σ = jax.jit(T_σ, static_argnums=(3,))
```


Next, we want to compute the lifetime value of following policy $\sigma$.

This lifetime value is a function $v_\sigma$ that satisfies
Expand Down Expand Up @@ -285,8 +271,7 @@ def L_σ(v, σ, constants, sizes, arrays):
L_σ = jax.jit(L_σ, static_argnums=(3,))
```

Now we can define a function to compute $v_{\sigma}$

Now we can define a function to compute $v_{\sigma}$

```{code-cell} ipython3
def get_value(σ, constants, sizes, arrays):
Expand All @@ -306,6 +291,11 @@ def get_value(σ, constants, sizes, arrays):
get_value = jax.jit(get_value, static_argnums=(2,))
```

We use successive approximation for VFI.

```{code-cell} ipython3
:load: _static/lecture_specific/successive_approx.py
```

Finally, we introduce the solvers that implement VFI, HPI and OPI.

Expand Down Expand Up @@ -355,7 +345,6 @@ print(out)
print(f"OPI completed in {elapsed} seconds.")
```


Here's the plot of the Howard policy, as a function of $y$ at the highest and lowest values of $z$.

```{code-cell} ipython3
Expand All @@ -377,7 +366,6 @@ ax.legend(fontsize=12)
plt.show()
```


Let's plot the time taken by each of the solvers and compare them.

```{code-cell} ipython3
Expand All @@ -403,6 +391,7 @@ print(f"VFI completed in {vfi_time} seconds.")

```{code-cell} ipython3
:tags: [hide-output]
opi_times = []
for m in m_vals:
print(f"Running optimistic policy iteration with m={m}.")
Expand Down
27 changes: 12 additions & 15 deletions lectures/opt_savings.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.5
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -65,20 +65,13 @@ where

$$ u(c) = \frac{c^{1-\gamma}}{1-\gamma} $$

+++

We use successive approximation for VFI.

```{code-cell} ipython3
:load: _static/lecture_specific/successive_approx.py
```

## Model primitives

First we define a model that stores parameters and grids.

```{code-cell} ipython3
def create_consumption_model(R=1.01, # Gross interest rate
def create_consumption_model(R=1.01, # Gross interest rate
β=0.98, # Discount factor
γ=2, # CRRA parameter
w_min=0.01, # Min wealth
Expand Down Expand Up @@ -140,8 +133,6 @@ which is defined as the vector

$$ r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$



```{code-cell} ipython3
def compute_r_σ(σ, constants, sizes, arrays):
"""
Expand Down Expand Up @@ -187,9 +178,9 @@ def T_σ(v, σ, constants, sizes, arrays):
Q = jnp.reshape(Q, (1, y_size, y_size))
# Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
Ev = jnp.sum(V * Q, axis=2)
EV = jnp.sum(V * Q, axis=2)
return r_σ + β * Ev
return r_σ + β * EV
```

and the Bellman operator $T$
Expand Down Expand Up @@ -260,7 +251,7 @@ def L_σ(v, σ, constants, sizes, arrays):
return v - β * jnp.sum(V * Q, axis=2)
```

Now we can define a function to compute $v_{\sigma}$
Now we can define a function to compute $v_{\sigma}$

```{code-cell} ipython3
def get_value(σ, constants, sizes, arrays):
Expand Down Expand Up @@ -291,6 +282,12 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
L_σ = jax.jit(L_σ, static_argnums=(3,))
```

We use successive approximation for VFI.

```{code-cell} ipython3
:load: _static/lecture_specific/successive_approx.py
```

## Solvers

Now we define the solvers, which implement VFI, HPI and OPI.
Expand Down Expand Up @@ -353,7 +350,7 @@ print("Starting VFI.")
start_time = time.time()
out = value_iteration(model)
elapsed = time.time() - start_time
print(f"VFI(jax not in succ) completed in {elapsed} seconds.")
print(f"VFI completed in {elapsed} seconds.")
```

```{code-cell} ipython3
Expand Down

0 comments on commit 4beb13a

Please sign in to comment.