Skip to content

Commit

Permalink
Use callback in successive_approx
Browse files Browse the repository at this point in the history
  • Loading branch information
kp992 committed Mar 6, 2024
1 parent 1cd0306 commit 463ce9f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
11 changes: 3 additions & 8 deletions lectures/_static/lecture_specific/successive_approx.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
def successive_approx_jax(x_0, # Initial condition
constants,
sizes,
arrays,
def successive_approx_jax(T, # Operator (callable)
x_0, # Initial condition
tolerance=1e-6, # Error tolerance
max_iter=10_000): # Max iteration bound

def body_fun(k_x_err):
k, x, error = k_x_err
x_new = T(x, constants, sizes, arrays)
x_new = T(x)
error = jnp.max(jnp.abs(x_new - x))
return k + 1, x_new, error

Expand All @@ -17,5 +14,3 @@ def cond_fun(k_x_err):

k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))
return x

successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(2,))
4 changes: 2 additions & 2 deletions lectures/_static/lecture_specific/vfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
def value_iteration(model, tol=1e-5):
constants, sizes, arrays = model
vz = jnp.zeros(sizes)

v_star = successive_approx_jax(vz, constants, sizes, arrays, tolerance=tol)
_T = lambda v: T(v, constants, sizes, arrays)
v_star = successive_approx_jax(_T, vz, tolerance=tol)
return get_greedy(v_star, constants, sizes, arrays)

0 comments on commit 463ce9f

Please sign in to comment.