Update job search lecture #190

Merged 3 commits on Aug 12, 2024
13 changes: 10 additions & 3 deletions lectures/autodiff.md
@@ -13,6 +13,12 @@ kernelspec:

# Adventures with Autodiff


```{include} _admonition/gpu.md
```

## Overview

This lecture gives a brief introduction to automatic differentiation using
Google JAX.

@@ -25,14 +31,15 @@ powerful implementations available.
One of the best of these is the automatic differentiation routines contained
in JAX.

While other software packages also offer this feature, the JAX version is
particularly powerful because it integrates so well with other core
components of JAX (e.g., JIT compilation and parallelization).
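
For example (a toy illustration, not taken from the lecture), a gradient produced by `jax.grad` can be passed straight through `jax.jit` and `jax.vmap`:

```python
# Toy example: compose autodiff with JIT compilation and vectorization.
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(jnp.tanh(x) ** 2)     # scalar-valued function
grad_f = jax.jit(jax.grad(f))               # compiled gradient
batched = jax.vmap(jax.grad(f))             # gradient applied over a batch

x = jnp.linspace(-1.0, 1.0, 5)
print(grad_f(x))                            # gradient at one point (a vector)
print(batched(jnp.stack([x, 2 * x])))       # gradients for a batch of inputs
```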

As we will see in later lectures, automatic differentiation can be used not only
for AI but also for many problems faced in mathematical modeling, such as
multi-dimensional nonlinear optimization and root-finding problems.


We need the following imports

```{code-cell} ipython3
144 changes: 109 additions & 35 deletions lectures/job_search.md
@@ -17,11 +17,15 @@ kernelspec:
```


In this lecture we study a basic infinite-horizon job search problem with Markov wage
draws.

```{note}
For background on infinite horizon job search see, e.g., [DP1](https://dp.quantecon.org/).
```

The exercise at the end asks you to add risk-sensitive preferences and see how
the main results change.

In addition to what’s in Anaconda, this lecture will need the following libraries:

@@ -49,23 +53,32 @@

We study an elementary model where

* jobs are permanent
* unemployed workers receive current compensation $c$
* the wage offer distribution $\{W_t\}$ is Markovian
* the horizon is infinite
* an unemployed agent discounts the future via discount factor $\beta \in (0,1)$

### Set up

At the start of each period, an unemployed worker receives wage offer $W_t$.

To build a wage offer process we consider the dynamics

$$
W_{t+1} = \rho W_t + \nu Z_{t+1}
$$

where $(Z_t)_{t \geq 0}$ is IID and standard normal.

We then discretize this wage process using Tauchen's method to produce a stochastic matrix $P$.

Successive wage offers are drawn from $P$.
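
To build intuition, here is a small illustrative sketch (not part of the lecture code) of simulating wage offers from such a discretized chain; the grid size and parameter values are placeholders, and the `tauchen` call follows the signature used later in the lecture.

```python
# Illustrative sketch only: simulate wage offers from the discretized chain.
import numpy as np
import quantecon as qe

mc = qe.tauchen(25, 0.9, 0.2)            # small grid for illustration
w_vals = np.exp(mc.state_values)         # wage values
P = mc.P                                 # transition matrix over wage states
idx = qe.MarkovChain(P).simulate(ts_length=8, init=0)
print(w_vals[idx])                       # one sample path of wage offers
```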

### Rewards

Since jobs are permanent, the return to accepting wage offer $w$ today is

$$
w + \beta w + \beta^2 w + \cdots = \frac{w}{1-\beta}
$$
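
As a quick sanity check (an illustration, not part of the lecture), we can confirm the geometric-sum formula numerically for arbitrary values of $w$ and $\beta$:

```python
# Truncated geometric sum vs. the closed form w / (1 - β); values are arbitrary.
β, w = 0.99, 2.0
truncated = sum(β**t * w for t in range(10_000))
print(truncated, w / (1 - β))    # both approximately 200
```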

The Bellman equation is
@@ -79,30 +92,50 @@

We solve this model using value function iteration.

+++

## Code

Let's set up a `namedtuple` to store information needed to solve the model.

```{code-cell} ipython3
Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c'))
```

The function below holds default values and populates the `namedtuple`.

```{code-cell} ipython3
def create_js_model(
        n=500,        # wage grid size
        ρ=0.9,        # wage persistence
        ν=0.2,        # wage volatility
        β=0.99,       # discount factor
        c=1.0,        # unemployment compensation
    ):
    "Creates an instance of the job search model with Markov wages."
    mc = qe.tauchen(n, ρ, ν)
    w_vals, P = jnp.exp(mc.state_values), jnp.array(mc.P)
    return Model(n, w_vals, P, β, c)
```

Let's test it:

```{code-cell} ipython3
model = create_js_model(β=0.98)
```

```{code-cell} ipython3
model.c
```

```{code-cell} ipython3
model.β
```

```{code-cell} ipython3
model.w_vals.mean()
```

Here's the Bellman operator.

```{code-cell} ipython3
@@ -135,13 +168,13 @@

Here $\mathbf 1$ is an indicator function.

* $\sigma(w) = 1$ means stop
* $\sigma(w) = 0$ means continue.

```{code-cell} ipython3
@jax.jit
def get_greedy(v, model):
    "Get a v-greedy policy."
    n, w_vals, P, β, c = model
    e = w_vals / (1 - β)
    h = c + β * P @ v
@@ -153,8 +186,7 @@

Here's a routine for value function iteration.

```{code-cell} ipython3
def vfi(model, max_iter=10_000, tol=1e-4):
    "Solve the infinite-horizon Markov job search model by VFI."
    print("Starting VFI iteration.")
    v = jnp.zeros_like(model.w_vals)    # Initial guess
    i = 0
@@ -171,29 +203,47 @@ def vfi(model, max_iter=10_000, tol=1e-4):
    return v_star, σ_star
```

+++

## Computing the solution

Let's set up and solve the model.

```{code-cell} ipython3
model = create_js_model()
n, w_vals, P, β, c = model

v_star, σ_star = vfi(model)
```

Here's the optimal policy:

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(σ_star)
ax.set_xlabel("wage values")
ax.set_ylabel("optimal choice (stop=1)")
plt.show()
```

We compute the reservation wage as the first $w$ such that $\sigma(w)=1$.

```{code-cell} ipython3
stop_indices = jnp.where(σ_star == 1)
stop_indices
```

```{code-cell} ipython3
res_wage_index = min(stop_indices[0])
```

```{code-cell} ipython3
res_wage = w_vals[res_wage_index]
```

Here's a joint plot of the value function and the reservation wage.

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(w_vals, v_star, alpha=0.8, label="value function")
@@ -228,13 +278,37 @@
$$


When $\theta < 0$ the agent is risk averse.

Solve the model when $\theta = -0.1$ and compare your result to the risk neutral
case.

Try to interpret your result.

You can start with the following code:

```{code-cell} ipython3

RiskModel = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c', 'θ'))

def create_risk_sensitive_js_model(
        n=500,        # wage grid size
        ρ=0.9,        # wage persistence
        ν=0.2,        # wage volatility
        β=0.99,       # discount factor
        c=1.0,        # unemployment compensation
        θ=-0.1        # risk parameter
    ):
    "Creates an instance of the job search model with Markov wages."
    mc = qe.tauchen(n, ρ, ν)
    w_vals, P = jnp.exp(mc.state_values), mc.P
    P = jnp.array(P)
    return RiskModel(n, w_vals, P, β, c, θ)

```

Now you need to modify `T` and `get_greedy` and then run value function iteration again.
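
If it helps, here is one possible sketch of the modified Bellman operator, assuming the risk-sensitive continuation value $c + \frac{\beta}{\theta} \ln \sum_{w'} P(w, w') e^{\theta v(w')}$ described above; treat it as a starting point to check against your own derivation, and adapt `get_greedy` in the same way.

```python
# A sketch only: risk-sensitive Bellman operator, assuming the exponential
# certainty-equivalent continuation value described above.  Relies on the
# jax / jnp imports used earlier in the lecture.
@jax.jit
def T_rs(v, model):
    "Apply one risk-sensitive Bellman update to the value function v."
    n, w_vals, P, β, c, θ = model
    e = w_vals / (1 - β)                             # value of accepting at each wage
    h = c + (β / θ) * jnp.log(P @ jnp.exp(θ * v))    # risk-adjusted continuation value
    return jnp.maximum(e, h)
```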

```{exercise-end}
```

@@ -311,25 +385,25 @@
model_rs = create_risk_sensitive_js_model()

n, w_vals, P, β, c, θ = model_rs

v_star_rs, σ_star_rs = vfi(model_rs)
```

Let's plot the results together with the original risk neutral case and see what we get.

```{code-cell} ipython3
stop_indices = jnp.where(σ_star_rs == 1)
res_wage_index = min(stop_indices[0])
res_wage_rs = w_vals[res_wage_index]
```

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(w_vals, v_star, alpha=0.8, label="RN $v$")
ax.plot(w_vals, v_star_rs, alpha=0.8, label="RS $v$")
ax.vlines((res_wage,), 150, 400, ls='--', color='darkblue', alpha=0.5, label=r"RV $\bar w$")
ax.vlines((res_wage_rs,), 150, 400, ls='--', color='orange', alpha=0.5, label=r"RS $\bar w$")
ax.plot(w_vals, v_star, alpha=0.8, label="risk neutral $v$")
ax.plot(w_vals, v_star_rs, alpha=0.8, label="risk sensitive $v$")
ax.vlines((res_wage,), 100, 400, ls='--', color='darkblue',
alpha=0.5, label=r"risk neutral $\bar w$")
ax.vlines((res_wage_rs,), 100, 400, ls='--', color='orange',
alpha=0.5, label=r"risk sensitive $\bar w$")
ax.legend(frameon=False, fontsize=12, loc="lower right")
ax.set_xlabel("$w$", fontsize=12)
plt.show()
Expand Down
10 changes: 3 additions & 7 deletions lectures/newtons_method.md
@@ -20,18 +20,14 @@ kernelspec:

One of the key features of JAX is automatic differentiation.

We introduced this feature in {doc}`autodiff`.

In this lecture we apply automatic differentiation to the problem of computing economic equilibria via Newton's method.

Newton's method is a relatively simple root and fixed point solution algorithm, which we discussed
in [a more elementary QuantEcon lecture](https://python.quantecon.org/newton_method.html).
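
As a refresher (a minimal scalar sketch, separate from the multivariate treatment developed below), Newton's method for solving $f(x) = 0$ iterates $x_{k+1} = x_k - f(x_k)/f'(x_k)$, with automatic differentiation supplying $f'$:

```python
# Toy illustration: scalar Newton iteration with the derivative from jax.grad.
import jax

def newton(f, x_0, tol=1e-5, max_iter=50):
    "Find a root of f starting from x_0, differentiating f with jax.grad."
    f_prime = jax.grad(f)
    x = x_0
    for _ in range(max_iter):
        x_new = x - f(x) / f_prime(x)
        if abs(x_new - x) < tol:
            break
        x = x_new
    return x_new

f = lambda x: x**3 - 2.0            # the root is the cube root of 2
print(newton(f, 1.0))               # ≈ 1.2599
```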

JAX is ideally suited to implementing Newton's method efficiently, even in high dimensions.

We use the following imports in this lecture
