Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
jstac committed Aug 11, 2024
1 parent 021d61b commit f6cc508
Showing 1 changed file with 95 additions and 31 deletions.
126 changes: 95 additions & 31 deletions lectures/job_search.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,24 @@ We study an elementary model where
* the horizon is infinite
* an unemployment agent discounts the future via discount factor $\beta \in (0,1)$

The wage process obeys
### Set up

The wage offer process obeys

$$
W_{t+1} = \rho W_t + \nu Z_{t+1},
\qquad \{Z_t\} \text{ is IID and } N(0, 1)
W_{t+1} = \rho W_t + \nu Z_{t+1}
$$

We discretize this using Tauchen's method to produce a stochastic matrix $P$
where $(Z_t)_{t \geq 0}$ is IID and standard normal.

We discretize this wage process using Tauchen's method to produce a stochastic matrix $P$

### Rewards

Since jobs are permanent, the return to accepting wage offer $w$ today is

$$
w + \beta w + \beta^2 w + \cdots = \frac{w}{1-\beta}
w + \beta w + \beta^2 w + \frac{w}{1-\beta}
$$

The Bellman equation is
Expand All @@ -79,8 +84,11 @@ $$

We solve this model using value function iteration.

+++

## Code

Let's set up a `namedtuple` to store information needed to solve the model.
Let's set up a namedtuple to store information needed to solve the model.

```{code-cell} ipython3
Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c'))
Expand All @@ -94,15 +102,32 @@ def create_js_model(
ρ=0.9, # wage persistence
ν=0.2, # wage volatility
β=0.99, # discount factor
c=1.0 # unemployment compensation
c=1.0, # unemployment compensation
):
"Creates an instance of the job search model with Markov wages."
mc = qe.tauchen(n, ρ, ν)
w_vals, P = jnp.exp(mc.state_values), mc.P
P = jnp.array(P)
w_vals, P = jnp.exp(mc.state_values), jnp.array(mc.P)
return Model(n, w_vals, P, β, c)
```

Let's test it:

```{code-cell} ipython3
model = create_js_model(β=0.98)
```

```{code-cell} ipython3
model.c
```

```{code-cell} ipython3
model.β
```

```{code-cell} ipython3
model.w_vals.mean()
```

Here's the Bellman operator.

```{code-cell} ipython3
Expand Down Expand Up @@ -135,13 +160,13 @@ $$

Here $\mathbf 1$ is an indicator function.

The statement above means that the worker accepts ($\sigma(w) = 1$) when the value of stopping
is higher than the value of continuing.
* $\sigma(w) = 1$ means stop
* $\sigma(w) = 0$ means continue.

```{code-cell} ipython3
@jax.jit
def get_greedy(v, model):
"""Get a v-greedy policy."""
"Get a v-greedy policy."
n, w_vals, P, β, c = model
e = w_vals / (1 - β)
h = c + β * P @ v
Expand All @@ -153,8 +178,7 @@ Here's a routine for value function iteration.

```{code-cell} ipython3
def vfi(model, max_iter=10_000, tol=1e-4):
"""Solve the infinite-horizon Markov job search model by VFI."""
"Solve the infinite-horizon Markov job search model by VFI."
print("Starting VFI iteration.")
v = jnp.zeros_like(model.w_vals) # Initial guess
i = 0
Expand All @@ -171,27 +195,43 @@ def vfi(model, max_iter=10_000, tol=1e-4):
return v_star, σ_star
```

### Computing the solution

+++

## Computing the solution

Let's set up and solve the model.

```{code-cell} ipython3
model = create_js_model()
n, w_vals, P, β, c = model
%time v_star, σ_star = vfi(model)
v_star, σ_star = vfi(model)
```

We run it again to eliminate compile time.
Here's the optimal policy:

```{code-cell} ipython3
%time v_star, σ_star = vfi(model)
fig, ax = plt.subplots()
ax.plot(σ_star)
ax.set_xlabel("wage values")
ax.set_ylabel("optimal choice (stop=1)")
plt.show()
```

We compute the reservation wage as the first $w$ such that $\sigma(w)=1$.

```{code-cell} ipython3
res_wage = w_vals[jnp.searchsorted(σ_star, 1.0)]
stop_indices = jnp.where(σ_star == 1)
stop_indices
```

```{code-cell} ipython3
res_wage_index = min(stop_indices[0])
```

```{code-cell} ipython3
res_wage = w_vals[res_wage_index]
```

```{code-cell} ipython3
Expand Down Expand Up @@ -228,13 +268,37 @@ $$
$$


When $\theta < 0$ the agent is risk sensitive.
When $\theta < 0$ the agent is risk averse.

Solve the model when $\theta = -0.1$ and compare your result to the risk neutral
case.

Try to interpret your result.

You can start with the following code:

```{code-cell} ipython3
Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c', 'θ'))
```

```{code-cell} ipython3
def create_risk_sensitive_js_model(
n=500, # wage grid size
ρ=0.9, # wage persistence
ν=0.2, # wage volatility
β=0.99, # discount factor
c=1.0, # unemployment compensation
θ=-0.1 # risk parameter
):
"Creates an instance of the job search model with Markov wages."
mc = qe.tauchen(n, ρ, ν)
w_vals, P = jnp.exp(mc.state_values), mc.P
P = jnp.array(P)
return Model(n, w_vals, P, β, c, θ)
```

Now you need to modify `T` and `get_greedy` and then run value function iteration again.

```{exercise-end}
```

Expand Down Expand Up @@ -311,25 +375,25 @@ model_rs = create_risk_sensitive_js_model()
n, w_vals, P, β, c, θ = model_rs
%time v_star_rs, σ_star_rs = vfi(model_rs)
v_star_rs, σ_star_rs = vfi(model_rs)
```

We run it again to eliminate the compilation time.

```{code-cell} ipython3
%time v_star_rs, σ_star_rs = vfi(model_rs)
```
Let's plot the results together with the original risk neutral case and see what we get.

```{code-cell} ipython3
res_wage_rs = w_vals[jnp.searchsorted(σ_star_rs, 1.0)]
stop_indices = jnp.where(σ_star_rs == 1)
res_wage_index = min(stop_indices[0])
res_wage_rs = w_vals[res_wage_index]
```

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(w_vals, v_star, alpha=0.8, label="RN $v$")
ax.plot(w_vals, v_star_rs, alpha=0.8, label="RS $v$")
ax.vlines((res_wage,), 150, 400, ls='--', color='darkblue', alpha=0.5, label=r"RV $\bar w$")
ax.vlines((res_wage_rs,), 150, 400, ls='--', color='orange', alpha=0.5, label=r"RS $\bar w$")
ax.plot(w_vals, v_star, alpha=0.8, label="risk neutral $v$")
ax.plot(w_vals, v_star_rs, alpha=0.8, label="risk sensitive $v$")
ax.vlines((res_wage,), 100, 400, ls='--', color='darkblue',
alpha=0.5, label=r"risk neutral $\bar w$")
ax.vlines((res_wage_rs,), 100, 400, ls='--', color='orange',
alpha=0.5, label=r"risk sensitive $\bar w$")
ax.legend(frameon=False, fontsize=12, loc="lower right")
ax.set_xlabel("$w$", fontsize=12)
plt.show()
Expand Down

0 comments on commit f6cc508

Please sign in to comment.