From f6cc5089d6059c2ccc16f135e46375bf26e9b253 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Mon, 12 Aug 2024 08:52:18 +1000 Subject: [PATCH] misc --- lectures/job_search.md | 126 +++++++++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 31 deletions(-) diff --git a/lectures/job_search.md b/lectures/job_search.md index ecb9bf5..d5c4b0f 100644 --- a/lectures/job_search.md +++ b/lectures/job_search.md @@ -53,19 +53,24 @@ We study an elementary model where * the horizon is infinite * an unemployment agent discounts the future via discount factor $\beta \in (0,1)$ -The wage process obeys +### Set up + +The wage offer process obeys $$ - W_{t+1} = \rho W_t + \nu Z_{t+1}, - \qquad \{Z_t\} \text{ is IID and } N(0, 1) + W_{t+1} = \rho W_t + \nu Z_{t+1} $$ -We discretize this using Tauchen's method to produce a stochastic matrix $P$ +where $(Z_t)_{t \geq 0}$ is IID and standard normal. + +We discretize this wage process using Tauchen's method to produce a stochastic matrix $P$ + +### Rewards Since jobs are permanent, the return to accepting wage offer $w$ today is $$ - w + \beta w + \beta^2 w + \cdots = \frac{w}{1-\beta} + w + \beta w + \beta^2 w + \frac{w}{1-\beta} $$ The Bellman equation is @@ -79,8 +84,11 @@ $$ We solve this model using value function iteration. ++++ + +## Code -Let's set up a `namedtuple` to store information needed to solve the model. +Let's set up a namedtuple to store information needed to solve the model. ```{code-cell} ipython3 Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c')) @@ -94,15 +102,32 @@ def create_js_model( ρ=0.9, # wage persistence ν=0.2, # wage volatility β=0.99, # discount factor - c=1.0 # unemployment compensation + c=1.0, # unemployment compensation ): "Creates an instance of the job search model with Markov wages." mc = qe.tauchen(n, ρ, ν) - w_vals, P = jnp.exp(mc.state_values), mc.P - P = jnp.array(P) + w_vals, P = jnp.exp(mc.state_values), jnp.array(mc.P) return Model(n, w_vals, P, β, c) ``` +Let's test it: + +```{code-cell} ipython3 +model = create_js_model(β=0.98) +``` + +```{code-cell} ipython3 +model.c +``` + +```{code-cell} ipython3 +model.β +``` + +```{code-cell} ipython3 +model.w_vals.mean() +``` + Here's the Bellman operator. ```{code-cell} ipython3 @@ -135,13 +160,13 @@ $$ Here $\mathbf 1$ is an indicator function. -The statement above means that the worker accepts ($\sigma(w) = 1$) when the value of stopping -is higher than the value of continuing. +* $\sigma(w) = 1$ means stop +* $\sigma(w) = 0$ means continue. ```{code-cell} ipython3 @jax.jit def get_greedy(v, model): - """Get a v-greedy policy.""" + "Get a v-greedy policy." n, w_vals, P, β, c = model e = w_vals / (1 - β) h = c + β * P @ v @@ -153,8 +178,7 @@ Here's a routine for value function iteration. ```{code-cell} ipython3 def vfi(model, max_iter=10_000, tol=1e-4): - """Solve the infinite-horizon Markov job search model by VFI.""" - + "Solve the infinite-horizon Markov job search model by VFI." print("Starting VFI iteration.") v = jnp.zeros_like(model.w_vals) # Initial guess i = 0 @@ -171,7 +195,10 @@ def vfi(model, max_iter=10_000, tol=1e-4): return v_star, σ_star ``` -### Computing the solution + ++++ + +## Computing the solution Let's set up and solve the model. @@ -179,19 +206,32 @@ Let's set up and solve the model. model = create_js_model() n, w_vals, P, β, c = model -%time v_star, σ_star = vfi(model) +v_star, σ_star = vfi(model) ``` -We run it again to eliminate compile time. +Here's the optimal policy: ```{code-cell} ipython3 -%time v_star, σ_star = vfi(model) +fig, ax = plt.subplots() +ax.plot(σ_star) +ax.set_xlabel("wage values") +ax.set_ylabel("optimal choice (stop=1)") +plt.show() ``` We compute the reservation wage as the first $w$ such that $\sigma(w)=1$. ```{code-cell} ipython3 -res_wage = w_vals[jnp.searchsorted(σ_star, 1.0)] +stop_indices = jnp.where(σ_star == 1) +stop_indices +``` + +```{code-cell} ipython3 +res_wage_index = min(stop_indices[0]) +``` + +```{code-cell} ipython3 +res_wage = w_vals[res_wage_index] ``` ```{code-cell} ipython3 @@ -228,13 +268,37 @@ $$ $$ -When $\theta < 0$ the agent is risk sensitive. +When $\theta < 0$ the agent is risk averse. Solve the model when $\theta = -0.1$ and compare your result to the risk neutral case. Try to interpret your result. +You can start with the following code: + +```{code-cell} ipython3 +Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c', 'θ')) +``` + +```{code-cell} ipython3 +def create_risk_sensitive_js_model( + n=500, # wage grid size + ρ=0.9, # wage persistence + ν=0.2, # wage volatility + β=0.99, # discount factor + c=1.0, # unemployment compensation + θ=-0.1 # risk parameter + ): + "Creates an instance of the job search model with Markov wages." + mc = qe.tauchen(n, ρ, ν) + w_vals, P = jnp.exp(mc.state_values), mc.P + P = jnp.array(P) + return Model(n, w_vals, P, β, c, θ) +``` + +Now you need to modify `T` and `get_greedy` and then run value function iteration again. + ```{exercise-end} ``` @@ -311,25 +375,25 @@ model_rs = create_risk_sensitive_js_model() n, w_vals, P, β, c, θ = model_rs -%time v_star_rs, σ_star_rs = vfi(model_rs) +v_star_rs, σ_star_rs = vfi(model_rs) ``` -We run it again to eliminate the compilation time. - -```{code-cell} ipython3 -%time v_star_rs, σ_star_rs = vfi(model_rs) -``` +Let's plot the results together with the original risk neutral case and see what we get. ```{code-cell} ipython3 -res_wage_rs = w_vals[jnp.searchsorted(σ_star_rs, 1.0)] +stop_indices = jnp.where(σ_star_rs == 1) +res_wage_index = min(stop_indices[0]) +res_wage_rs = w_vals[res_wage_index] ``` ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(w_vals, v_star, alpha=0.8, label="RN $v$") -ax.plot(w_vals, v_star_rs, alpha=0.8, label="RS $v$") -ax.vlines((res_wage,), 150, 400, ls='--', color='darkblue', alpha=0.5, label=r"RV $\bar w$") -ax.vlines((res_wage_rs,), 150, 400, ls='--', color='orange', alpha=0.5, label=r"RS $\bar w$") +ax.plot(w_vals, v_star, alpha=0.8, label="risk neutral $v$") +ax.plot(w_vals, v_star_rs, alpha=0.8, label="risk sensitive $v$") +ax.vlines((res_wage,), 100, 400, ls='--', color='darkblue', + alpha=0.5, label=r"risk neutral $\bar w$") +ax.vlines((res_wage_rs,), 100, 400, ls='--', color='orange', + alpha=0.5, label=r"risk sensitive $\bar w$") ax.legend(frameon=False, fontsize=12, loc="lower right") ax.set_xlabel("$w$", fontsize=12) plt.show()