diff --git a/lectures/aiyagari_jax.md b/lectures/aiyagari_jax.md index 37ae62e6..15246ab8 100644 --- a/lectures/aiyagari_jax.md +++ b/lectures/aiyagari_jax.md @@ -55,6 +55,7 @@ import matplotlib.pyplot as plt import numpy as np import jax import jax.numpy as jnp +from collections import namedtuple ``` Let's check the GPU we are running @@ -69,7 +70,6 @@ We will use 64 bit floats with JAX in order to increase the precision. jax.config.update("jax_enable_x64", True) ``` - We will use the following function to compute stationary distributions of stochastic matrices. (For a reference to the algorithm, see p. 88 of [Economic Dynamics](https://johnstachurski.net/edtc).) ```{code-cell} ipython3 @@ -84,7 +84,6 @@ def compute_stationary(P): return jnp.linalg.solve(A, jnp.ones(n)) ``` - ## Firms Firms produce output by hiring capital and labor. @@ -116,54 +115,61 @@ $$ The parameter $ \delta $ is the depreciation rate. -From the first-order condition with respect to capital, the firm’s inverse demand for capital is +These parameters are stored in the following namedtuple. +```{code-cell} ipython3 +Firm = namedtuple('Firm', ('A', 'N', 'α', 'β', 'δ')) -```{math} -:label: equation-aiy-rgk -r = A \alpha \left( \frac{N}{K} \right)^{1 - \alpha} - \delta +def create_firm(A=1.0, + N=1.0, + α=0.33, + β=0.96, + δ=0.05): + return Firm(A=A, N=N, α=α, β=β, δ=δ) ``` -Using this expression and the firm’s first-order condition for labor, we can pin down -the equilibrium wage rate as a function of $ r $ as +From the first-order condition with respect to capital, + +the firm’s inverse demand for capital is ```{math} -:label: equation-aiy-wgr -w(r) = A (1 - \alpha) (A \alpha / (r + \delta))^{\alpha / (1 - \alpha)} +:label: equation-aiy-rgk +r = A \alpha \left( \frac{N}{K} \right)^{1 - \alpha} - \delta ``` -These parameters and equations are stored in the following class. - ```{code-cell} ipython3 -class Firm: +def rd(K, f): + """ + Inverse demand curve for capital. The interest rate associated with a + given demand for capital K. + """ + A, N, α, β, δ = f + return A * α * (N / K)**(1 - α) - δ - def __init__(self, - A=1.0, - N=1.0, - α=0.33, - β=0.96, - δ=0.05): +rd = jax.jit(rd, static_argnums=(1,)) +``` - self.A, self.N, self.α, self.β, self.δ = A, N, α, β, δ +Using {eq}`equation-aiy-rgk` and the firm’s first-order condition for labor, - def rd(self, K): - """ - Inverse demand curve for capital. The interest rate associated with a - given demand for capital K. - """ - A, N, α, β, δ = self.A, self.N, self.α, self.β, self.δ - return A * α * (N / K)**(1 - α) - δ +we can pin down the equilibrium wage rate as a function of $ r $ as - def r_to_w(self, r): - """ - Equilibrium wages associated with a given interest rate r. - """ - A, N, α, β, δ = self.A, self.N, self.α, self.β, self.δ - return A * (1 - α) * (A * α / (r + δ))**(α / (1 - α)) +```{math} +:label: equation-aiy-wgr +w(r) = A (1 - \alpha) (A \alpha / (r + \delta))^{\alpha / (1 - \alpha)} ``` +```{code-cell} ipython3 +def r_to_w(r, f): + """ + Equilibrium wages associated with a given interest rate r. + """ + A, N, α, β, δ = f + return A * (1 - α) * (A * α / (r + δ))**(α / (1 - α)) + +r_to_w = jax.jit(r_to_w, static_argnums=(1,)) +``` ## Households @@ -212,38 +218,30 @@ For now we assume that $u(c) = \log(c)$. ### Primitives and Operators -This class stores the parameters that define a household asset +This namedtuple stores the parameters that define a household asset accumulation problem and the grids used to solve it. ```{code-cell} ipython3 -class Household: - - def __init__(self, - r=0.01, # Interest rate - w=1.0, # Wages - β=0.96, # Discount factor - Π=[[0.9, 0.1], [0.1, 0.9]], # Markov chain - z_grid=[0.1, 1.0], # Exogenous states - a_min=1e-10, a_max=20, # Asset grid - a_size=200): - - # Store values, set up grids over a and z - self.r, self.w, self.β = r, w, β - self.a_size = a_size - self.a_grid = jnp.linspace(a_min, a_max, a_size) - z_grid, Π = map(jnp.array, (z_grid, Π)) - self.Π = jax.device_put(Π) - self.z_grid = jax.device_put(z_grid) - self.z_size = len(z_grid) - - def constants(self): - return self.r, self.w, self.β - - def sizes(self): - return self.a_size, self.z_size +Household = namedtuple('Household', ('r', 'w', 'β', 'a_size', 'z_size', \ + 'a_grid', 'z_grid', 'Π')) - def arrays(self): - return self.a_grid, self.z_grid, self.Π +def create_household(r=0.01, # Interest rate + w=1.0, # Wages + β=0.96, # Discount factor + Π=[[0.9, 0.1], [0.1, 0.9]], # Markov chain + z_grid=[0.1, 1.0], # Exogenous states + a_min=1e-10, a_max=20, # Asset grid + a_size=200): + + a_grid = jnp.linspace(a_min, a_max, a_size) + z_grid, Π = map(jnp.array, (z_grid, Π)) + Π = jax.device_put(Π) + z_grid = jax.device_put(z_grid) + z_size = len(z_grid) + a_grid, z_grid, Π = jax.device_put((a_grid, z_grid, Π)) + + return Household(r=r, w=w, β=β, a_size=a_size, z_size=z_size, \ + a_grid=a_grid, z_grid=z_grid, Π=Π) ``` ```{code-cell} ipython3 @@ -252,7 +250,6 @@ def u(c): return jnp.log(c) ``` - This is the vectorized version of the right-hand side of the Bellman equation (before maximization), which is a 3D array representing @@ -285,7 +282,6 @@ def B(v, constants, sizes, arrays): B = jax.jit(B, static_argnums=(2,)) ``` - The next function computes greedy policies. ```{code-cell} ipython3 @@ -296,7 +292,6 @@ def get_greedy(v, constants, sizes, arrays): get_greedy = jax.jit(get_greedy, static_argnums=(2,)) ``` - We need to know rewards at a given policy for policy iteration. The following functions computes the array $r_{\sigma}$ which gives current @@ -327,7 +322,6 @@ def compute_r_σ(σ, constants, sizes, arrays): compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,)) ``` - The value $v_{\sigma}$ of a policy $\sigma$ is defined as $$ @@ -343,8 +337,10 @@ $$ $$ Defining the map as above works in a more intuitive multi-index setting -(e.g. working with $v[i, j]$ rather than flattening $v$ to a one-dimensional -array) and avoids instantiating the large matrix $P_{\sigma}$. + +(e.g. working with $v[i, j]$ rather than flattening $v$ to a one-dimensional array) + +and avoids instantiating the large matrix $P_{\sigma}$. The following linear operator is also needed for policy iteration. @@ -370,7 +366,6 @@ def R_σ(v, σ, constants, sizes, arrays): R_σ = jax.jit(R_σ, static_argnums=(3,)) ``` - The next function computes the lifetime value of a given policy. ```{code-cell} ipython3 @@ -387,7 +382,6 @@ def get_value(σ, constants, sizes, arrays): get_value = jax.jit(get_value, static_argnums=(2,)) ``` - The following function is used for optimistic policy iteration. ```{code-cell} ipython3 @@ -418,7 +412,6 @@ def T_σ(v, σ, constants, sizes, arrays): T_σ = jax.jit(T_σ, static_argnums=(3,)) ``` - ## Solvers We will solve the household problem using Howard policy iteration. @@ -426,9 +419,12 @@ We will solve the household problem using Howard policy iteration. ```{code-cell} ipython3 def policy_iteration(household, verbose=True): """Howard policy iteration routine.""" - constants = household.constants() - sizes = household.sizes() - arrays = household.arrays() + + γ, w, β, a_size, z_size, a_grid, z_grid, Π = household + + constants = γ, w, β + sizes = a_size, z_size + arrays = a_grid, z_grid, Π vz = jnp.zeros(sizes) σ = jnp.zeros(sizes, dtype=int) @@ -444,14 +440,16 @@ def policy_iteration(household, verbose=True): return σ ``` - We can also solve the problem using optimistic policy iteration. ```{code-cell} ipython3 def optimistic_policy_iteration(household, tol=1e-5, m=10): - constants = household.constants() - sizes = household.sizes() - arrays = household.arrays() + + γ, w, β, a_size, z_size, a_grid, z_grid, Π = household + + constants = γ, w, β + sizes = a_size, z_size + arrays = a_grid, z_grid, Π v = jnp.zeros(sizes) error = tol + 1 @@ -464,7 +462,6 @@ def optimistic_policy_iteration(household, tol=1e-5, m=10): return get_greedy(v, constants, sizes, arrays) ``` - As a first example of what we can do, let’s compute and plot an optimal accumulation policy at fixed prices. ```{code-cell} ipython3 @@ -473,7 +470,7 @@ r = 0.03 w = 0.956 # Create an instance of Housbehold -household = Household(r=r, w=w) +household = create_household(r=r, w=w) ``` ```{code-cell} ipython3 @@ -488,12 +485,10 @@ household = Household(r=r, w=w) σ_star = optimistic_policy_iteration(household) ``` - The next plot shows asset accumulation policies at different values of the exogenous state. ```{code-cell} ipython3 -a_size, z_size = household.sizes() -a_grid, z_grid, Π = household.arrays() +γ, w, β, a_size, z_size, a_grid, z_grid, Π = household fig, ax = plt.subplots(figsize=(9, 9)) ax.plot(a_grid, a_grid, 'k--') # 45 degrees @@ -507,7 +502,6 @@ ax.legend(loc='upper left') plt.show() ``` - ### Capital Supply To start thinking about equilibrium, we need to know how much capital households supply at a given interest rate $r$. @@ -552,24 +546,22 @@ compute_asset_stationary = jax.jit(compute_asset_stationary, static_argnums=(2,)) ``` - Let's give this a test run. ```{code-cell} ipython3 -constants = household.constants() -sizes = household.sizes() -arrays = household.arrays() +γ, w, β, a_size, z_size, a_grid, z_grid, Π = household +constants = γ, w, β +sizes = a_size, z_size +arrays = a_grid, z_grid, Π ψ = compute_asset_stationary(σ_star, constants, sizes, arrays) ``` - The distribution should sum to one: ```{code-cell} ipython3 ψ.sum() ``` - Now we are ready to compute capital supply by households given wages and interest rates. ```{code-cell} ipython3 @@ -577,9 +569,13 @@ def capital_supply(household): """ Map household decisions to the induced level of capital stock. """ - constants = household.constants() - sizes = household.sizes() - arrays = household.arrays() + + # Unpack + γ, w, β, a_size, z_size, a_grid, z_grid, Π = household + + constants = γ, w, β + sizes = a_size, z_size + arrays = a_grid, z_grid, Π # Compute the optimal policy σ_star = optimistic_policy_iteration(household) @@ -587,10 +583,9 @@ def capital_supply(household): ψ_a = compute_asset_stationary(σ_star, constants, sizes, arrays) # Return K - return float(jnp.sum(ψ_a * household.a_grid)) + return float(jnp.sum(ψ_a * a_grid)) ``` - ## Equilibrium We construct a *stationary rational expectations equilibrium* (SREE). @@ -631,8 +626,8 @@ The intersection gives equilibrium interest rates and capital. ```{code-cell} ipython3 # Create default instances -household = Household() -firm = Firm() +household = create_household() +firm = create_firm() # Create a grid of r values at which to compute demand and supply of capital num_points = 50 @@ -645,8 +640,8 @@ r_vals = np.linspace(0.005, 0.04, num_points) # Compute supply of capital k_vals = np.empty(num_points) for i, r in enumerate(r_vals): - household.r = r - household.w = firm.r_to_w(r) + # _replace create a new nametuple with the updated parameters + household = household._replace(r=r, w=r_to_w(r, firm)) k_vals[i] = capital_supply(household) ``` @@ -655,7 +650,7 @@ for i, r in enumerate(r_vals): fig, ax = plt.subplots() ax.plot(k_vals, r_vals, lw=2, alpha=0.6, label='supply of capital') -ax.plot(k_vals, firm.rd(k_vals), lw=2, alpha=0.6, label='demand for capital') +ax.plot(k_vals, rd(k_vals, firm), lw=2, alpha=0.6, label='demand for capital') ax.grid() ax.set_xlabel('capital') ax.set_ylabel('interest rate') @@ -668,12 +663,11 @@ Here's a plot of the excess demand function. The equilibrium is the zero (root) of this function. - ```{code-cell} ipython3 def excess_demand(K, firm, household): - r = firm.rd(K) - w = firm.r_to_w(r) - household.r, household.w = r, w + r = rd(K, firm) + w = r_to_w(r, firm) + household = household._replace(r=r, w=w) return K - capital_supply(household) ``` @@ -694,7 +688,6 @@ ax.legend() plt.show() ``` - ### Computing the equilibrium Now let's compute the equilibrium @@ -702,7 +695,6 @@ Now let's compute the equilibrium To do so, we use the bisection method, which is implemented in the next function. - ```{code-cell} ipython3 def bisect(f, a, b, *args, tol=10e-2): """ @@ -726,7 +718,7 @@ def bisect(f, a, b, *args, tol=10e-2): Now we call the bisection function on excess demand. ```{code-cell} ipython3 -def compute_equilibrium(household, firm): +def compute_equilibrium(firm, household): solution = bisect(excess_demand, 6.0, 10.0, firm, household) return solution ``` @@ -734,9 +726,9 @@ def compute_equilibrium(household, firm): ```{code-cell} ipython3 %%time -household = Household() -firm = Firm() -compute_equilibrium(household, firm) +household = create_household() +firm = create_firm() +compute_equilibrium(firm, household) ``` Notice how quickly we can compute the equilibrium capital stock using a simple @@ -764,9 +756,9 @@ showing the behaviour of equilibrium capital stock with the increase in $\beta$. eq_vals = np.empty_like(β_vals) for i, β in enumerate(β_vals): - household = Household(β=β) - firm = Firm(β=β) - eq_vals[i] = compute_equilibrium(household, firm) + household = create_household(β=β) + firm = create_firm(β=β) + eq_vals[i] = compute_equilibrium(firm, household) ``` ```{code-cell} ipython3 @@ -777,7 +769,6 @@ ax.set_ylabel('equilibrium') plt.show() ``` - ```{solution-end} ``` @@ -816,7 +807,6 @@ def u(c, γ=2): return c**(1 - γ) / (1 - γ) ``` - We need to re-compile all the jitted functions in order notice the change in the utility function. @@ -831,13 +821,12 @@ compute_asset_stationary = jax.jit(compute_asset_stationary, static_argnums=(2,)) ``` - Now, let's plot the the demand for capital by firms ```{code-cell} ipython3 # Create default instances -household = Household() -firm = Firm() +household = create_household() +firm = create_firm() # Create a grid of r values at which to compute demand and supply of capital num_points = 50 @@ -847,8 +836,7 @@ r_vals = np.linspace(0.005, 0.04, num_points) # Compute supply of capital k_vals = np.empty(num_points) for i, r in enumerate(r_vals): - household.r = r - household.w = firm.r_to_w(r) + household = household._replace(r=r, w=r_to_w(r, firm)) k_vals[i] = capital_supply(household) ``` @@ -857,7 +845,7 @@ for i, r in enumerate(r_vals): fig, ax = plt.subplots() ax.plot(k_vals, r_vals, lw=2, alpha=0.6, label='supply of capital') -ax.plot(k_vals, firm.rd(k_vals), lw=2, alpha=0.6, label='demand for capital') +ax.plot(k_vals, rd(k_vals, firm), lw=2, alpha=0.6, label='demand for capital') ax.grid() ax.set_xlabel('capital') ax.set_ylabel('interest rate') @@ -866,15 +854,14 @@ ax.legend() plt.show() ``` - Compute the equilibrium ```{code-cell} ipython3 %%time -household = Household() -firm = Firm() -compute_equilibrium(household, firm) +household = create_household() +firm = create_firm() +compute_equilibrium(firm, household) ``` ```{solution-end}