From 6a852621ebfd4f7e1dc79df2f4ca6a1fb9fa14db Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 17 Jun 2024 10:27:01 +1000 Subject: [PATCH 1/3] misc --- lectures/aiyagari_jax.md | 459 +++++++++++++++++---------------------- 1 file changed, 203 insertions(+), 256 deletions(-) diff --git a/lectures/aiyagari_jax.md b/lectures/aiyagari_jax.md index a0a2013..d2cd966 100644 --- a/lectures/aiyagari_jax.md +++ b/lectures/aiyagari_jax.md @@ -51,11 +51,11 @@ A less sophisticated version of this lecture (without JAX) can be found We use the following imports ```{code-cell} ipython3 +import time import matplotlib.pyplot as plt import numpy as np import jax import jax.numpy as jnp -from time import time from collections import namedtuple ``` @@ -74,8 +74,6 @@ jax.config.update("jax_enable_x64", True) We will use the following function to compute stationary distributions of stochastic matrices. (For a reference to the algorithm, see p. 88 of [Economic Dynamics](https://johnstachurski.net/edtc).) ```{code-cell} ipython3 -# Compute the stationary distribution of P by matrix inversion. - @jax.jit def compute_stationary(P): n = P.shape[0] @@ -98,20 +96,20 @@ Hence we can consider a single (but nonetheless competitive) representative firm The firm’s output is $$ -Y_t = A K_t^{\alpha} N^{1 - \alpha} +Y = A K^{\alpha} N^{1 - \alpha} $$ where - $ A $ and $ \alpha $ are parameters with $ A > 0 $ and $ \alpha \in (0, 1) $ -- $ K_t $ is aggregate capital +- $ K$ is aggregate capital - $ N $ is total labor supply (which is constant in this simple version of the model) The firm’s problem is $$ -\max_{K, N} \left\{ A K_t^{\alpha} N^{1 - \alpha} - (r + \delta) K - w N \right\} +\max_{K, N} \left\{ A K^{\alpha} N^{1 - \alpha} - (r + \delta) K - w N \right\} $$ The parameter $ \delta $ is the depreciation rate. @@ -119,20 +117,21 @@ The parameter $ \delta $ is the depreciation rate. These parameters are stored in the following namedtuple. ```{code-cell} ipython3 -Firm = namedtuple('Firm', ('A', 'N', 'α', 'β', 'δ')) +Firm = namedtuple('Firm', ('A', 'N', 'α', 'δ')) def create_firm(A=1.0, N=1.0, α=0.33, - β=0.96, δ=0.05): - - return Firm(A=A, N=N, α=α, β=β, δ=δ) + """ + Create a namedtuple that stores firm data. + + """ + return Firm(A=A, N=N, α=α, δ=δ) ``` -From the first-order condition with respect to capital, +From the first-order condition with respect to capital, the firm’s inverse demand for capital is -the firm’s inverse demand for capital is ```{math} @@ -146,13 +145,12 @@ def r_given_k(K, firm): Inverse demand curve for capital. The interest rate associated with a given demand for capital K. """ - A, N, α, β, δ = firm + A, N, α, δ = firm return A * α * (N / K)**(1 - α) - δ ``` Using {eq}`equation-aiy-rgk` and the firm’s first-order condition for labor, -we can pin down the equilibrium wage rate as a function of $ r $ as ```{math} @@ -161,11 +159,11 @@ w(r) = A (1 - \alpha) (A \alpha / (r + \delta))^{\alpha / (1 - \alpha)} ``` ```{code-cell} ipython3 -def r_to_w(r, f): +def r_to_w(r, firm): """ Equilibrium wages associated with a given interest rate r. """ - A, N, α, β, δ = f + A, N, α, δ = firm return A * (1 - α) * (A * α / (r + δ))**(α / (1 - α)) ``` @@ -208,46 +206,56 @@ In this simple version of the model, households supply labor inelastically beca Below we provide code to solve the household problem, taking $r$ and $w$ as fixed. -For now we assume that $u(c) = \log(c)$. - -(CRRA utility is treated in the exercises.) - ### Primitives and Operators -This namedtuple stores the parameters that define a household asset -accumulation problem and the grids used to solve it. +We will solve the household problem using Howard policy iteration +(see Ch 5 of [Dynamic Programming](https://dp.quantecon.org/)). + +First we set up a namedtuple to store the parameters that define a household asset +accumulation problem, as well as the grids used to solve it. ```{code-cell} ipython3 -Household = namedtuple('Household', ('r', 'w', 'β', 'a_size', 'z_size', \ - 'a_grid', 'z_grid', 'Π')) +Household = namedtuple('Household', + ('β', 'a_grid', 'z_grid', 'Π')) +``` -def create_household(r=0.01, # Interest rate - w=1.0, # Wages - β=0.96, # Discount factor +```{code-cell} ipython3 +def create_household(β=0.96, # Discount factor Π=[[0.9, 0.1], [0.1, 0.9]], # Markov chain z_grid=[0.1, 1.0], # Exogenous states a_min=1e-10, a_max=20, # Asset grid a_size=200): - + """ + Create a namedtuple that stores all data needed to solve the household + problem, given prices. + + """ a_grid = jnp.linspace(a_min, a_max, a_size) z_grid, Π = map(jnp.array, (z_grid, Π)) - Π = jax.device_put(Π) - z_grid = jax.device_put(z_grid) - z_size = len(z_grid) - a_grid, z_grid, Π = jax.device_put((a_grid, z_grid, Π)) - - return Household(r=r, w=w, β=β, a_size=a_size, z_size=z_size, \ - a_grid=a_grid, z_grid=z_grid, Π=Π) + return Household(β=β, a_grid=a_grid, z_grid=z_grid, Π=Π) ``` +For now we assume that $u(c) = \log(c)$. + +(CRRA utility is treated in the exercises.) + ```{code-cell} ipython3 -def u(c): - return jnp.log(c) +u = jnp.log ``` -This is the vectorized version of the right-hand side of the Bellman equation +Here's a tuple that stores the wage rate and interest rate, as well as a function that creates a price namedtuple with default values. + +```{code-cell} ipython3 +Prices = namedtuple('Prices', ('r', 'w')) + +def create_prices(r=0.01, # Interest rate + w=1.0): # Wages + return Prices(r=r, w=w) +``` + +Now we set up a vectorized version of the right-hand side of the Bellman equation (before maximization), which is a 3D array representing $$ @@ -256,67 +264,68 @@ $$ for all $(a, z, a')$. ```{code-cell} ipython3 -def B(v, constants, sizes, arrays): +@jax.jit +def B(v, household, prices): # Unpack - r, w, β = constants - a_size, z_size = sizes - a_grid, z_grid, Π = arrays + β, a_grid, z_grid, Π = household + a_size, z_size = len(a_grid), len(z_grid) + r, w = prices # Compute current consumption as array c[i, j, ip] a = jnp.reshape(a_grid, (a_size, 1, 1)) # a[i] -> a[i, j, ip] z = jnp.reshape(z_grid, (1, z_size, 1)) # z[j] -> z[i, j, ip] ap = jnp.reshape(a_grid, (1, 1, a_size)) # ap[ip] -> ap[i, j, ip] - c = w*z + (1 + r)*a - ap + c = w * z + (1 + r) * a - ap # Calculate continuation rewards at all combinations of (a, z, ap) v = jnp.reshape(v, (1, 1, a_size, z_size)) # v[ip, jp] -> v[i, j, ip, jp] Π = jnp.reshape(Π, (1, z_size, 1, z_size)) # Π[j, jp] -> Π[i, j, ip, jp] - EV = jnp.sum(v * Π, axis=3) # sum over last index jp + EV = jnp.sum(v * Π, axis=-1) # sum over last index jp # Compute the right-hand side of the Bellman equation return jnp.where(c > 0, u(c) + β * EV, -jnp.inf) - -B = jax.jit(B, static_argnums=(2,)) ``` The next function computes greedy policies. ```{code-cell} ipython3 -# Computes a v-greedy policy, returned as a set of indices -def get_greedy(v, constants, sizes, arrays): - return jnp.argmax(B(v, constants, sizes, arrays), axis=2) +@jax.jit +def get_greedy(v, household, prices): + """ + Computes a v-greedy policy σ, returned as a set of indices. If + σ[i, j] equals ip, then a_grid[ip] is the maximizer at i, j. -get_greedy = jax.jit(get_greedy, static_argnums=(2,)) + """ + return jnp.argmax(B(v, household, prices), axis=-1) # argmax over ap ``` -We need to know rewards at a given policy for policy iteration. - -The following functions computes the array $r_{\sigma}$ which gives current +The following function computes the array $r_{\sigma}$ which gives current rewards given policy $\sigma$. -That is, +```{code-cell} ipython3 +@jax.jit +def compute_r_σ(σ, household, prices): + """ + Compute current rewards at each i, j under policy σ. In particular, -$$ - r_{\sigma}[i, j] = r[i, j, \sigma[i, j]] -$$ + r_σ[i, j] = u((1 + r)a[i] + wz[j] - a'[ip]) -```{code-cell} ipython3 -def compute_r_σ(σ, constants, sizes, arrays): + when ip = σ[i, j]. + + """ # Unpack - r, w, β = constants - a_size, z_size = sizes - a_grid, z_grid, Π = arrays + β, a_grid, z_grid, Π = household + a_size, z_size = len(a_grid), len(z_grid) + r, w = prices # Compute r_σ[i, j] a = jnp.reshape(a_grid, (a_size, 1)) z = jnp.reshape(z_grid, (1, z_size)) ap = a_grid[σ] - c = (1 + r)*a + w*z - ap + c = (1 + r) * a + w * z - ap r_σ = u(c) return r_σ - -compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,)) ``` The value $v_{\sigma}$ of a policy $\sigma$ is defined as @@ -325,79 +334,75 @@ $$ v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma} $$ -Here we set up the linear map $v \rightarrow R_{\sigma} v$, where $R_{\sigma} := I - \beta P_{\sigma}$. +(See Ch 5 of [Dynamic Programming](https://dp.quantecon.org/) for notation and background on Howard policy iteration.) + +To compute this vector, we set up the linear map $v \rightarrow R_{\sigma} v$, where $R_{\sigma} := I - \beta P_{\sigma}$. -In the consumption problem, this map can be expressed as +This map can be expressed as $$ (R_{\sigma} v)(a, z) = v(a, z) - \beta \sum_{z'} v(\sigma(a, z), z') Π(z, z') $$ -Defining the map as above works in a more intuitive multi-index setting - -(e.g. working with $v[i, j]$ rather than flattening $v$ to a one-dimensional array) - -and avoids instantiating the large matrix $P_{\sigma}$. +(Notice that $R_\sigma$ is expressed as a linear operator rather than a matrix -- this is much easier and cleaner to code, and also exploits sparsity.) ```{code-cell} ipython3 -def R_σ(v, σ, constants, sizes, arrays): +@jax.jit +def R_σ(v, σ, household): # Unpack - r, w, β = constants - a_size, z_size = sizes - a_grid, z_grid, Π = arrays + β, a_grid, z_grid, Π = household + a_size, z_size = len(a_grid), len(z_grid) # Set up the array v[σ[i, j], jp] zp_idx = jnp.arange(z_size) zp_idx = jnp.reshape(zp_idx, (1, 1, z_size)) σ = jnp.reshape(σ, (a_size, z_size, 1)) V = v[σ, zp_idx] - + # Expand Π[j, jp] to Π[i, j, jp] Π = jnp.reshape(Π, (1, z_size, z_size)) - + # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Π[j, jp] - return v - β * jnp.sum(V * Π, axis=2) - -R_σ = jax.jit(R_σ, static_argnums=(3,)) + return v - β * jnp.sum(V * Π, axis=-1) ``` The next function computes the lifetime value of a given policy. ```{code-cell} ipython3 -# Get the value v_σ of policy σ by inverting the linear map R_σ +@jax.jit +def get_value(σ, household, prices): + """ + Get the lifetime value of policy σ by computing -def get_value(σ, constants, sizes, arrays): + v_σ = R_σ^{-1} r_σ - r_σ = compute_r_σ(σ, constants, sizes, arrays) + """ + r_σ = compute_r_σ(σ, household, prices) # Reduce R_σ to a function in v - partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays) - # Compute inverse v_σ = (I - β P_σ)^{-1} r_σ - return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0] - -get_value = jax.jit(get_value, static_argnums=(2,)) + _R_σ = lambda v: R_σ(v, σ, household) + # Compute v_σ = R_σ^{-1} r_σ using an iterative routing. + return jax.scipy.sparse.linalg.bicgstab(_R_σ, r_σ)[0] ``` -## Solvers - -We will solve the household problem using Howard policy iteration. +Here's the Howard policy iteration. ```{code-cell} ipython3 -def policy_iteration(household, tol=1e-4, max_iter=10_000, verbose=False): - """Howard policy iteration routine.""" - - γ, w, β, a_size, z_size, a_grid, z_grid, Π = household - - constants = γ, w, β - sizes = a_size, z_size - arrays = a_grid, z_grid, Π +def howard_policy_iteration(household, prices, + tol=1e-4, max_iter=10_000, verbose=False): + """ + Howard policy iteration routine. - σ = jnp.zeros(sizes, dtype=int) - v_σ = get_value(σ, constants, sizes, arrays) + """ + β, a_grid, z_grid, Π = household + a_size, z_size = len(a_grid), len(z_grid) + σ = jnp.zeros((a_size, z_size), dtype=int) + + v_σ = get_value(σ, household, prices) i = 0 error = tol + 1 while error > tol and i < max_iter: - σ_new = get_greedy(v_σ, constants, sizes, arrays) - v_σ_new = get_value(σ_new, constants, sizes, arrays) + σ_new = get_greedy(v_σ, household, prices) + v_σ_new = get_value(σ_new, household, prices) error = jnp.max(jnp.abs(v_σ_new - v_σ)) σ = σ_new v_σ = v_σ_new @@ -410,26 +415,27 @@ def policy_iteration(household, tol=1e-4, max_iter=10_000, verbose=False): As a first example of what we can do, let’s compute and plot an optimal accumulation policy at fixed prices. ```{code-cell} ipython3 -# Create an instance of Housbehold +# Create an instance of Household household = create_household() +prices = create_prices() ``` ```{code-cell} ipython3 -%%time -σ_star = policy_iteration(household, verbose=True).block_until_ready() +r, w = prices ``` -We run it again to get rid of compile time. +```{code-cell} ipython3 +r, w +``` ```{code-cell} ipython3 -%%time -σ_star = policy_iteration(household, verbose=True).block_until_ready() +%time σ_star = howard_policy_iteration(household, prices, verbose=True) ``` The next plot shows asset accumulation policies at different values of the exogenous state. ```{code-cell} ipython3 -γ, w, β, a_size, z_size, a_grid, z_grid, Π = household +β, a_grid, z_grid, Π = household fig, ax = plt.subplots() ax.plot(a_grid, a_grid, 'k--', label="45 degrees") @@ -449,19 +455,18 @@ To start thinking about equilibrium, we need to know how much capital households This quantity can be calculated by taking the stationary distribution of assets under the optimal policy and computing the mean. -The next function implements this calculation for a given policy $\sigma$. +The next function computes the stationary distribution for a given policy $\sigma$ via the following steps: -First we compute the stationary distribution of $P_{\sigma}$, which is for the -bivariate Markov chain of the state $(a_t, z_t)$. Then we sum out -$z_t$ to get the marginal distribution for $a_t$. +* compute the stationary distribution $\psi = (\psi(a, z))$ of $P_{\sigma}$, which defines the + Markov chain of the state $(a_t, z_t)$ under policy $\sigma$. +* sum out $z_t$ to get the marginal distribution for $a_t$. ```{code-cell} ipython3 -def compute_asset_stationary(σ, constants, sizes, arrays): - +@jax.jit +def compute_asset_stationary(σ, household): # Unpack - r, w, β = constants - a_size, z_size = sizes - a_grid, z_grid, Π = arrays + β, a_grid, z_grid, Π = household + a_size, z_size = len(a_grid), len(z_grid) # Construct P_σ as an array of the form P_σ[i, j, ip, jp] ap_idx = jnp.arange(a_size) @@ -475,7 +480,7 @@ def compute_asset_stationary(σ, constants, sizes, arrays): n = a_size * z_size P_σ = jnp.reshape(P_σ, (n, n)) - # Get stationary distribution and reshape onto [i, j] grid + # Get stationary distribution and reshape back onto [i, j] grid ψ = compute_stationary(P_σ) ψ = jnp.reshape(ψ, (a_size, z_size)) @@ -490,138 +495,86 @@ compute_asset_stationary = jax.jit(compute_asset_stationary, Let's give this a test run. ```{code-cell} ipython3 -γ, w, β, a_size, z_size, a_grid, z_grid, Π = household -constants = γ, w, β -sizes = a_size, z_size -arrays = a_grid, z_grid, Π -ψ = compute_asset_stationary(σ_star, constants, sizes, arrays) +ψ_a = compute_asset_stationary(σ_star, household) +``` + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.bar(household.a_grid, ψ_a) +ax.set_xlabel("asset level") +ax.set_ylabel("probability mass") +plt.show() ``` The distribution should sum to one: ```{code-cell} ipython3 -ψ.sum() +ψ_a.sum() ``` -Now we are ready to compute capital supply by households given wages and interest rates. +The next function computes aggregate capital supply by households under policy $\sigma$, given wages and interest rates. ```{code-cell} ipython3 -def capital_supply(household): - """ - Map household decisions to the induced level of capital stock. +def capital_supply(σ, household): """ - - # Unpack - γ, w, β, a_size, z_size, a_grid, z_grid, Π = household + Induced level of capital stock under the policy, taking r and w as given. - constants = γ, w, β - sizes = a_size, z_size - arrays = a_grid, z_grid, Π - - # Compute the optimal policy - σ_star = policy_iteration(household) - # Compute the stationary distribution - ψ_a = compute_asset_stationary(σ_star, constants, sizes, arrays) - - # Return K + """ + β, a_grid, z_grid, Π = household + ψ_a = compute_asset_stationary(σ, household) return float(jnp.sum(ψ_a * a_grid)) ``` ## Equilibrium -We construct a *stationary rational expectations equilibrium* (SREE). - -In such an equilibrium - -- prices induce behavior that generates aggregate quantities consistent with the prices -- aggregate quantities and prices are constant over time - +We compute a **stationary rational expectations equilibrium** (SREE) as follows: -In more detail, an SREE lists a set of prices, savings and production policies such that - -- households want to choose the specified savings policies taking the prices as given -- firms maximize profits taking the same prices as given -- the resulting aggregate quantities are consistent with the prices; in particular, the demand for capital equals the supply -- aggregate quantities (defined as cross-sectional averages) are constant - - -In practice, once parameter values are set, we can check for an SREE by the following steps - -1. pick a proposed quantity $ K $ for aggregate capital -2. determine corresponding prices, with interest rate $ r $ determined by {eq}`equation-aiy-rgk` and a wage rate $ w(r) $ as given in {eq}`equation-aiy-wgr`. -3. determine the common optimal savings policy of the households given these prices -4. compute aggregate capital as the mean of steady state capital given this savings policy - - -If this final quantity agrees with $ K $ then we have a SREE. Otherwise we adjust $K$. - -These steps describe a fixed point problem which we solve below. - -### Visual inspection +1. set $n=0$, start with initial guess $ K_0$ for aggregate capital +2. determine prices $r, w$ from the firm decision problem, given $K_n$ +3. compute the optimal savings policy of the households given these prices +4. compute aggregate capital $K_{n+1}$ as the mean of steady state capital given this savings policy +5. if $K_{n+1} \approx K_n$ stop, otherwise go to step 2. + +We can write the sequence of operations in steps 2-4 as -Let’s inspect visually as a first pass. - -The following code draws aggregate supply and demand curves for capital. - -The intersection gives equilibrium interest rates and capital. - -```{code-cell} ipython3 -# Create default instances -household = create_household() -firm = create_firm() - -# Create a grid of r values at which to compute demand and supply of capital -num_points = 50 -r_vals = np.linspace(0.005, 0.04, num_points) -``` - -```{code-cell} ipython3 -%%time - -# Compute supply of capital -k_vals = np.empty(num_points) -for i, r in enumerate(r_vals): - # _replace create a new nametuple with the updated parameters - household = household._replace(r=r, w=r_to_w(r, firm)) - k_vals[i] = capital_supply(household) -``` - -```{code-cell} ipython3 -# Plot against demand for capital by firms - -fig, ax = plt.subplots() -ax.plot(k_vals, r_vals, lw=2, alpha=0.6, label='supply of capital') -ax.plot(k_vals, r_given_k(k_vals, firm), lw=2, alpha=0.6, label='demand for capital') -ax.set_xlabel('capital') -ax.set_ylabel('interest rate') -ax.legend(loc='upper right') +$$ +K_{n + 1} = G(K_n) +$$ -plt.show() -``` +If $K_{n+1}$ agrees with $K_n$ then we have a SREE. -Here's a plot of the excess demand function. +In other words, our problem is to find the fixed-point of the one-dimensional map $G$. -The equilibrium is the zero (root) of this function. +Here's $G$ expressed as a Python function: ```{code-cell} ipython3 -def excess_demand(K, firm, household): +def G(K, firm, household): + # Get prices r, w associated with K r = r_given_k(K, firm) w = r_to_w(r, firm) - household = household._replace(r=r, w=w) - return K - capital_supply(household) + # Generate a household object with these prices, compute + # aggregate capital. + prices = create_prices(r=r, w=w) + σ_star = howard_policy_iteration(household, prices) + return capital_supply(σ_star, household) ``` +### Visual inspection + +Let’s inspect visually as a first pass. + ```{code-cell} ipython3 -%%time num_points = 50 +firm = create_firm() +household = create_household() k_vals = np.linspace(4, 12, num_points) -out = [excess_demand(k, firm, household) for k in k_vals] +out = [G(k, firm, household) for k in k_vals] ``` ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(k_vals, out, lw=2, alpha=0.6, label='excess demand') -ax.plot(k_vals, np.zeros_like(k_vals), 'k--', label="45") +ax.plot(k_vals, out, lw=2, alpha=0.6, label='$G$') +ax.plot(k_vals, k_vals, 'k--', label="45 degrees") ax.set_xlabel('capital') ax.legend() plt.show() @@ -629,58 +582,52 @@ plt.show() ### Computing the equilibrium -Now let's compute the equilibrium +Now let's compute the equilibrium. -To do so, we use the bisection method, which is implemented -in the next function. +Looking at the figure above, we see that a simple iteration scheme $K_{n+1} = G(K_n)$ will cycle from high to low values, leading to slow convergence. -```{code-cell} ipython3 -def bisect(f, a, b, *args, tol=10e-2): - """ - Implements the bisection root finding algorithm, assuming that f is a - real-valued function on [a, b] satisfying f(a) < 0 < f(b). - """ - lower, upper = a, b - count = 0 - while upper - lower > tol and count < 10000: - middle = 0.5 * (upper + lower) - if f(middle, *args) > 0: # root is between lower and middle - lower, upper = lower, middle - else: # root is between middle and upper - lower, upper = middle, upper - count += 1 - if count == 10000: - print("Root might not be accurate") - return 0.5 * (upper + lower), count -``` +As a result, we use a damped iteration scheme of the form -Now we call the bisection function on excess demand. +$$K_{n+1} = \alpha K_n + (1-\alpha) G(K_n)$$ ```{code-cell} ipython3 -def compute_equilibrium(firm, household): - print("\nComputing equilibrium capital stock") - solution, count = bisect(excess_demand, 6.0, 10.0, firm, household) - - start = time() - solution, count = bisect(excess_demand, 6.0, 10.0, firm, household) - bisect_without_compile = time() - start - print(f"Computed equilibrium in {count} iterations and {bisect_without_compile} seconds") - return solution +def compute_equilibrium(firm, household, + K0=6, α=0.99, max_iter=1_000, tol=1e-4, + print_skip=10, verbose=False): + n = 0 + K = K0 + error = tol + 1 + while error > tol and n < max_iter: + new_K = α * K + (1 - α) * G(K, firm, household) + error = abs(new_K - K) + K = new_K + n += 1 + if verbose and n % print_skip == 0: + print(f"At iteration {n} with error {error}") + return K, n ``` ```{code-cell} ipython3 -%%time -household = create_household() firm = create_firm() -compute_equilibrium(firm, household) +household = create_household() +print("\nComputing equilibrium capital stock") +start = time.time() +K_star, n = compute_equilibrium(firm, household, K0=6.0, verbose=True) +elapsed = time.time() - start +print(f"Computed equilibrium {K_star:.5} in {n} iterations and {elapsed} seconds") ``` -Notice how quickly we can compute the equilibrium capital stock using a simple -method such as bisection. +This is not very fast, given how quickly we can solve the household problem. + +You can try varying $\alpha$, but usually this parameter is hard to set a priori. + +In the exercises below you will be asked to use bisection instead, which generally performs better. ++++ ## Exercises ++++ ```{exercise-start} :label: aygr_ex1 From f634cc50c88a2ece22decee012f87157af03164c Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 17 Jun 2024 12:43:37 +1000 Subject: [PATCH 2/3] misc --- lectures/aiyagari_jax.md | 129 +++++++++++++++------------------------ 1 file changed, 48 insertions(+), 81 deletions(-) diff --git a/lectures/aiyagari_jax.md b/lectures/aiyagari_jax.md index d2cd966..0077b35 100644 --- a/lectures/aiyagari_jax.md +++ b/lectures/aiyagari_jax.md @@ -488,8 +488,6 @@ def compute_asset_stationary(σ, household): ψ_a = jnp.sum(ψ, axis=1) return ψ_a -compute_asset_stationary = jax.jit(compute_asset_stationary, - static_argnums=(2,)) ``` Let's give this a test run. @@ -630,125 +628,94 @@ In the exercises below you will be asked to use bisection instead, which general +++ ```{exercise-start} -:label: aygr_ex1 ``` -Using the default household and firm model, produce a graph -showing the behaviour of equilibrium capital stock with the increase in $\beta$. + +Write a new version of `compute_equilibrium` that uses `bisect` from `scipy.optimize` instead of damped iteration. + +See if you can make it faster that the previous version. + +In `bisect`, + +* you should set `xtol=1e-4` to have the same error tolerance as the previous version. +* for the lower and upper bounds of the bisection routine try `a = 1.0` and `b = 20.0`. ```{exercise-end} ``` -```{solution-start} aygr_ex1 + +```{solution-start} :class: dropdown ``` ```{code-cell} ipython3 -β_vals = np.linspace(0.9, 0.99, 40) -eq_vals = np.empty_like(β_vals) - -for i, β in enumerate(β_vals): - household = create_household(β=β) - firm = create_firm(β=β) - eq_vals[i] = compute_equilibrium(firm, household) +from scipy.optimize import bisect ``` -```{code-cell} ipython3 -fig, ax = plt.subplots() -ax.plot(β_vals, eq_vals, ms=2) -ax.set_xlabel(r'$\beta$') -ax.set_ylabel('equilibrium') -plt.show() -``` +We use bisection to find the zero of the function $h(k) = k - G(k)$. -```{solution-end} +```{code-cell} ipython3 +def compute_equilibrium(firm, household, a=1.0, b=20.0): + K = bisect(lambda k: k - G(k, firm, household), a, b, xtol=1e-4) + return K ``` - -```{exercise-start} -:label: aygr_ex2 +```{code-cell} ipython3 +firm = create_firm() +household = create_household() +print("\nComputing equilibrium capital stock") +start = time.time() +K_star = compute_equilibrium(firm, household) +elapsed = time.time() - start +print(f"Computed equilibrium capital stock {K_star:.5} in {elapsed} seconds") ``` -Switch to the CRRA utility function +Bisection seems to be faster than the damped iteration scheme. -$$ - u(c) =\frac{c^{1-\gamma}}{1-\gamma} -$$ -and re-do the plot of demand for capital by firms against the -supply of captial. - -Also, recompute the equilibrium. +```{solution-end} +``` -Use the default parameters for households and firms. -Set $\gamma=2$. -```{exercise-end} +```{exercise-start} ``` -```{solution-start} aygr_ex2 -:class: dropdown -``` +Show how equilibrium capital stock changes with $\beta$. -Let's define the utility function +Use the following values of $\beta$ and plot the relationship you find. ```{code-cell} ipython3 -def u(c, γ=2): - return c**(1 - γ) / (1 - γ) +β_vals = np.linspace(0.94, 0.98, 20) ``` -We need to re-compile all the jitted functions in order notice the change -in the utility function. - -```{code-cell} ipython3 -B = jax.jit(B, static_argnums=(2,)) -get_greedy = jax.jit(get_greedy, static_argnums=(2,)) -compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,)) -R_σ = jax.jit(R_σ, static_argnums=(3,)) -get_value = jax.jit(get_value, static_argnums=(2,)) -compute_asset_stationary = jax.jit(compute_asset_stationary, - static_argnums=(2,)) +```{exercise-end} ``` -Now, let's plot the the demand for capital by firms -```{code-cell} ipython3 -# Create default instances -household = create_household() -firm = create_firm() -# Create a grid of r values at which to compute demand and supply of capital -num_points = 50 -r_vals = np.linspace(0.005, 0.04, num_points) - -# Compute supply of capital -k_vals = np.empty(num_points) -for i, r in enumerate(r_vals): - household = household._replace(r=r, w=r_to_w(r, firm)) - k_vals[i] = capital_supply(household) +```{solution-start} +:class: dropdown ``` ```{code-cell} ipython3 -# Plot against demand for capital by firms +K_vals = np.empty_like(β_vals) +K = 6.0 # initial guess -fig, ax = plt.subplots() -ax.plot(k_vals, r_vals, lw=2, alpha=0.6, label='supply of capital') -ax.plot(k_vals, r_given_k(k_vals, firm), lw=2, alpha=0.6, label='demand for capital') -ax.set_xlabel('capital') -ax.set_ylabel('interest rate') -ax.legend() - -plt.show() +for i, β in enumerate(β_vals): + household = create_household(β=β) + K = compute_equilibrium(firm, household, 0.5 * K, 1.5 * K) + print(f"Computed equilibrium {K:.4} at β = {β}") + K_vals[i] = K ``` -Compute the equilibrium - ```{code-cell} ipython3 -%%time -household = create_household() -firm = create_firm() -compute_equilibrium(firm, household) +fig, ax = plt.subplots() +ax.plot(β_vals, K_vals, ms=2) +ax.set_xlabel(r'$\beta$') +ax.set_ylabel('capital') +plt.show() ``` ```{solution-end} ``` + From d44f79b38d3dafb70865441337736e124186f363 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 20 Jun 2024 17:00:16 +1000 Subject: [PATCH 3/3] @mmcky fixes --- lectures/aiyagari_jax.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lectures/aiyagari_jax.md b/lectures/aiyagari_jax.md index 0077b35..240dd8d 100644 --- a/lectures/aiyagari_jax.md +++ b/lectures/aiyagari_jax.md @@ -586,7 +586,9 @@ Looking at the figure above, we see that a simple iteration scheme $K_{n+1} = G( As a result, we use a damped iteration scheme of the form -$$K_{n+1} = \alpha K_n + (1-\alpha) G(K_n)$$ +$$ +K_{n+1} = \alpha K_n + (1-\alpha) G(K_n) +$$ ```{code-cell} ipython3 def compute_equilibrium(firm, household, @@ -621,13 +623,11 @@ You can try varying $\alpha$, but usually this parameter is hard to set a priori In the exercises below you will be asked to use bisection instead, which generally performs better. -+++ ## Exercises -+++ - ```{exercise-start} +:label: aiyagari-ex1 ``` Write a new version of `compute_equilibrium` that uses `bisect` from `scipy.optimize` instead of damped iteration. @@ -643,7 +643,7 @@ In `bisect`, ``` -```{solution-start} +```{solution-start} aiyagari-ex1 :class: dropdown ``` @@ -678,6 +678,7 @@ Bisection seems to be faster than the damped iteration scheme. ```{exercise-start} +:label: aiyagari-ex2 ``` Show how equilibrium capital stock changes with $\beta$. @@ -693,7 +694,7 @@ Use the following values of $\beta$ and plot the relationship you find. -```{solution-start} +```{solution-start} aiyagari-ex2 :class: dropdown ```