-
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[arellano] get rid of jitclass #131
Changes from 2 commits
2fed0fb
8a1f149
762521e
f0a5874
ff86e0f
520add0
f4af18e
3a3d23a
90bf9c7
0dfa2b3
cf734ef
a3723f0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ jupytext: | |
extension: .md | ||
format_name: myst | ||
format_version: 0.13 | ||
jupytext_version: 1.14.5 | ||
jupytext_version: 1.15.2 | ||
kernelspec: | ||
display_name: Python 3 (ipykernel) | ||
language: python | ||
|
@@ -77,6 +77,7 @@ import random | |
|
||
import jax | ||
import jax.numpy as jnp | ||
from collections import namedtuple | ||
``` | ||
|
||
Let's check the GPU we are running | ||
|
@@ -365,55 +366,43 @@ The output process is discretized using a [quadrature method due to Tauchen](htt | |
|
||
As we have in other places, we accelerate our code using Numba. | ||
|
||
We define a class that will store parameters, grids and transition | ||
We define a namedtuple to store parameters, grids and transition | ||
probabilities. | ||
|
||
```{code-cell} ipython3 | ||
:hide-output: false | ||
Arellano_Economy = namedtuple('Arellano_Economy', ('β', 'γ', 'r', 'ρ', 'η', 'θ', \ | ||
'B_size', 'y_size', \ | ||
'P', 'B_grid', 'y_grid', 'def_y')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As noted above, we can migrate the comments below to here. (Also, we can change lines without |
||
``` | ||
|
||
class Arellano_Economy: | ||
" Stores data and creates primitives for the Arellano economy. " | ||
|
||
def __init__(self, | ||
B_grid_size=251, # Grid size for bonds | ||
B_grid_min=-0.45, # Smallest B value | ||
B_grid_max=0.45, # Largest B value | ||
y_grid_size=51, # Grid size for income | ||
β=0.953, # Time discount parameter | ||
γ=2.0, # Utility parameter | ||
r=0.017, # Lending rate | ||
ρ=0.945, # Persistence in the income process | ||
η=0.025, # Standard deviation of the income process | ||
θ=0.282, # Prob of re-entering financial markets | ||
def_y_param=0.969): # Parameter governing income in default | ||
|
||
# Save parameters | ||
self.β, self.γ, self.r, = β, γ, r | ||
self.ρ, self.η, self.θ = ρ, η, θ | ||
|
||
# Set up grids | ||
self.y_grid_size = y_grid_size | ||
self.B_grid_size = B_grid_size | ||
B_grid = jnp.linspace(B_grid_min, B_grid_max, B_grid_size) | ||
mc = qe.markov.tauchen(y_grid_size, ρ, η) | ||
y_grid, P = jnp.exp(mc.state_values), mc.P | ||
|
||
# Put grids on the device | ||
self.B_grid = jax.device_put(B_grid) | ||
self.y_grid = jax.device_put(y_grid) | ||
self.P = jax.device_put(P) | ||
|
||
# Output received while in default, with same shape as y_grid | ||
self.def_y = jnp.minimum(def_y_param * jnp.mean(self.y_grid), self.y_grid) | ||
|
||
def params(self): | ||
return self.β, self.γ, self.r, self.ρ, self.η, self.θ | ||
|
||
def sizes(self): | ||
return self.B_grid_size, self.y_grid_size | ||
|
||
def arrays(self): | ||
return self.P, self.B_grid, self.y_grid, self.def_y | ||
```{code-cell} ipython3 | ||
def create_arellano(B_size=251, # Grid size for bonds | ||
B_min=-0.45, # Smallest B value | ||
B_max=0.45, # Largest B value | ||
y_size=51, # Grid size for income | ||
β=0.953, # Time discount parameter | ||
γ=2.0, # Utility parameter | ||
r=0.017, # Lending rate | ||
ρ=0.945, # Persistence in the income process | ||
η=0.025, # Standard deviation of the income process | ||
θ=0.282, # Prob of re-entering financial markets | ||
def_y_param=0.969): # Parameter governing income in default | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we prefer to comment on parameters when they are defined for the first time. Please move these comments on parameters to where we define the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @HumphreyYang , I will add the comments to both places, given the comments from another PR. |
||
# Set up grids | ||
B_grid = jnp.linspace(B_min, B_max, B_size) | ||
mc = qe.markov.tauchen(y_size, ρ, η) | ||
y_grid, P = jnp.exp(mc.state_values), mc.P | ||
|
||
# Put grids on the device | ||
B_grid = jax.device_put(B_grid) | ||
y_grid = jax.device_put(y_grid) | ||
P = jax.device_put(P) | ||
|
||
# Output received while in default, with same shape as y_grid | ||
def_y = jnp.minimum(def_y_param * jnp.mean(y_grid), y_grid) | ||
|
||
return Arellano_Economy(β=β, γ=γ, r=r, ρ=ρ, η=η, θ=θ, B_size=B_size, \ | ||
y_size=y_size, P=P, B_grid=B_grid, y_grid=y_grid, \ | ||
def_y=def_y) | ||
``` | ||
|
||
Here is the utility function. | ||
|
@@ -473,6 +462,7 @@ def T_d(v_c, v_d, params, sizes, arrays): | |
β, γ, r, ρ, η, θ = params | ||
B_size, y_size = sizes | ||
P, B_grid, y_grid, def_y = arrays | ||
|
||
B0_idx = jnp.searchsorted(B_grid, 1e-10) # Index at which B is near zero | ||
|
||
current_utility = u(def_y, γ) | ||
|
@@ -519,7 +509,6 @@ def bellman(v_c, v_d, q, params, sizes, arrays): | |
# Return new_v_c[i_B, i_y, i_Bp] | ||
val = jnp.where(c > 0, u(c, γ) + β * continuation_value, -jnp.inf) | ||
return val | ||
|
||
``` | ||
|
||
```{code-cell} ipython3 | ||
|
@@ -558,8 +547,8 @@ def update_values_and_prices(v_c, v_d, params, sizes, arrays): | |
return new_v_c, new_v_d | ||
``` | ||
|
||
We can now write a function that will use the `Arellano_Economy` class and the | ||
functions defined above to compute the solution to our model. | ||
We can now write a function that will use an instance of `Arellano_Economy` and | ||
the functions defined above to compute the solution to our model. | ||
|
||
One of the jobs of this function is to take an instance of | ||
`Arellano_Economy`, which is hard for the JIT compiler to handle, and strip it | ||
|
@@ -570,14 +559,16 @@ down to more basic objects, which are then passed out to jitted functions. | |
|
||
def solve(model, tol=1e-8, max_iter=10_000): | ||
""" | ||
Given an instance of Arellano_Economy, this function computes the optimal | ||
Given an instance of `Arellano_Economy`, this function computes the optimal | ||
policy and value functions. | ||
""" | ||
# Unpack | ||
params = model.params() | ||
sizes = model.sizes() | ||
arrays = model.arrays() | ||
B_size, y_size = sizes | ||
|
||
β, γ, r, ρ, η, θ, B_size, y_size, P, B_grid, y_grid, def_y = model | ||
|
||
params = β, γ, r, ρ, η, θ | ||
sizes = B_size, y_size | ||
arrays = P, B_grid, y_grid, def_y | ||
|
||
# Initial conditions for v_c and v_d | ||
v_c = jnp.zeros((B_size, y_size)) | ||
|
@@ -605,7 +596,7 @@ Let's try solving the model. | |
```{code-cell} ipython3 | ||
:hide-output: false | ||
|
||
ae = Arellano_Economy() | ||
ae = create_arellano() | ||
``` | ||
|
||
```{code-cell} ipython3 | ||
|
@@ -637,8 +628,9 @@ def simulate(model, T, v_c, v_d, q, B_star, key): | |
|
||
""" | ||
# Unpack elements of the model | ||
B_size, y_size = model.sizes() | ||
B_size, y_size = model.B_size, model.y_size | ||
B_grid, y_grid, P = model.B_grid, model.y_grid, model.P | ||
|
||
B0_idx = jnp.searchsorted(B_grid, 1e-10) # Index at which B is near zero | ||
|
||
# Set initial conditions | ||
|
@@ -695,18 +687,15 @@ def simulate(model, T, v_c, v_d, q, B_star, key): | |
|
||
Let’s start by trying to replicate the results obtained in {cite}`Are08`. | ||
|
||
In what follows, all results are computed using Arellano’s parameter values. | ||
|
||
The values can be seen in the `__init__` method of the `Arellano_Economy` | ||
shown above. | ||
In what follows, all results are computed using parameter values of `Arellano_Economy` created by `create_arellano`. | ||
|
||
For example, `r=0.017` matches the average quarterly rate on a 5 year US treasury over the period 1983–2001. | ||
|
||
Details on how to compute the figures are reported as solutions to the | ||
exercises. | ||
|
||
The first figure shows the bond price schedule and replicates Figure 3 of | ||
Arellano, where $ y_L $ and $ Y_H $ are particular below average and above average | ||
{cite}`Are08`, where $ y_L $ and $ Y_H $ are particular below average and above average | ||
values of output $ y $. | ||
|
||
![https://python-advanced.quantecon.org/_static/lecture_specific/arellano/arellano_bond_prices.png](https://python-advanced.quantecon.org/_static/lecture_specific/arellano/arellano_bond_prices.png) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we move these images to this repo as we will migrate the advanced series to theme-specific series soon? (CC @mmcky):
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @HumphreyYang . I reckon @mmcky have already done that in PR #141 |
||
|
@@ -716,7 +705,7 @@ values of output $ y $. | |
- $ y_H $ is 5% above the mean of the $ y $ grid values | ||
|
||
|
||
The grid used to compute this figure was relatively fine (`y_grid_size, B_grid_size = 51, 251`), which explains the minor differences between this and | ||
The grid used to compute this figure was relatively fine (`y_size, B_size = 51, 251`), which explains the minor differences between this and | ||
Arrelano’s figure. | ||
|
||
The figure shows that | ||
|
@@ -766,7 +755,7 @@ Periods of relative stability are followed by sharp spikes in the discount rate | |
|
||
To the extent that you can, replicate the figures shown above | ||
|
||
- Use the parameter values listed as defaults in `Arellano_Economy`. | ||
- Use the parameter values listed as defaults in `Arellano_Economy` created by `create_arellano`. | ||
- The time series will of course vary depending on the shock draws. | ||
|
||
```{exercise-end} | ||
|
@@ -785,18 +774,18 @@ Compute the value function, policy and equilibrium prices | |
```{code-cell} ipython3 | ||
:hide-output: false | ||
|
||
ae = Arellano_Economy() | ||
ae = create_arellano() | ||
v_c, v_d, q, B_star = solve(ae) | ||
``` | ||
|
||
Compute the bond price schedule as seen in figure 3 of Arellano (2008) | ||
Compute the bond price schedule as seen in figure 3 of {cite}`Are08` | ||
|
||
```{code-cell} ipython3 | ||
:hide-output: false | ||
|
||
# Unpack some useful names | ||
B_grid, y_grid, P = ae.B_grid, ae.y_grid, ae.P | ||
B_size, y_size = ae.sizes() | ||
B_size, y_size = ae.B_size, ae.y_size | ||
r = ae.r | ||
|
||
# Create "Y High" and "Y Low" values as 5% devs from mean | ||
|
@@ -811,7 +800,7 @@ x = [] | |
q_low = [] | ||
q_high = [] | ||
for i, B in enumerate(B_grid): | ||
if -0.35 <= B <= 0: # To match fig 3 of Arellano | ||
if -0.35 <= B <= 0: # To match fig 3 of Arellano (2008) | ||
x.append(B) | ||
q_low.append(q[i, iy_low]) | ||
q_high.append(q[i, iy_high]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Class (
namedtuple
) names should follow CamelCaseThere are some
Arellano_Economy
below in the text as well, so please update them together : )There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! We should change the name in the original lecture too:
https://python-advanced.quantecon.org/arellano.html