diff --git a/lectures/opt_invest.md b/lectures/opt_invest.md index 625cd73e..5efac033 100644 --- a/lectures/opt_invest.md +++ b/lectures/opt_invest.md @@ -25,8 +25,7 @@ We require the following library to be installed. !pip install --upgrade quantecon ``` -A monopolist faces inverse demand -curve +We study a monopolist who faces inverse demand curve $$ P_t = a_0 - a_1 Y_t + Z_t, @@ -38,7 +37,7 @@ where * $Y_t$ is output and * $Z_t$ is a demand shock. -We assume that $Z_t$ is a discretized AR(1) process. +We assume that $Z_t$ is a discretized AR(1) process, specified below. Current profits are @@ -116,10 +115,10 @@ def create_investment_model( Let's re-write the vectorized version of the right-hand side of the -Bellman equation (before maximization), which is a 3D array representing: +Bellman equation (before maximization), which is a 3D array representing $$ - B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z') + B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z') $$ for all $(y, z, y')$. @@ -154,8 +153,10 @@ def B(v, constants, sizes, arrays): B = jax.jit(B, static_argnums=(2,)) ``` +We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$, +which is defined as -Define a function to compute the current rewards given policy $\sigma$. +$$ r_\sigma(y, z) := r(y, z, \sigma(y, z)) $$ ```{code-cell} ipython3 def compute_r_σ(σ, constants, sizes, arrays): @@ -238,44 +239,29 @@ T_σ = jax.jit(T_σ, static_argnums=(3,)) Next, we want to computes the lifetime value of following policy $\sigma$. -The basic problem is to solve the linear system +This lifetime value is a function $v_\sigma$ that satisfies -$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$ +$$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z) $$ for $v$. -It turns out to be helpful to rewrite this as +Suppose we define the linear operator $R_\sigma$ by -$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{y', z'} v(y', z') P_\sigma(y, z, y', z') $$ +$$ (R_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$ -where $P_\sigma(y, z, y', z') = 1\{y' = \sigma(y, z)\} Q(z, z')$. - -We want to write this as $v = r_\sigma + \beta P_\sigma v$ and then solve for $v$ - -Note, however, that $v$ is a multi-index array, rather than a vector. - - -The value $v_{\sigma}$ of a policy $\sigma$ is defined as +With this notation, the problem is to solve for $v$ via $$ - v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma} + (R_{\sigma} v)(y, z) = r_\sigma(y, z) $$ -Here we set up the linear map $v \mapsto R_{\sigma} v$, - -where $R_{\sigma} := I - \beta P_{\sigma}$ - -In the investment problem, this map can be expressed as - -$$ - (R_{\sigma} v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') -$$ +In vector for this is $R_\sigma v = r_\sigma$, which tells us that the function +we seek is -Defining the map as above works in a more intuitive multi-index setting -(e.g. working with $v[i, j]$ rather than flattening v to a one-dimensional -array) and avoids instantiating the large matrix $P_{\sigma}$. +$$ v_\sigma = R_\sigma^{-1} r_\sigma $$ -Let's define the function $R_{\sigma}$. +JAX allows us to solve linear systems defined in terms of operators; the first +step is to define the function $R_{\sigma}$. ```{code-cell} ipython3 def R_σ(v, σ, constants, sizes, arrays): @@ -299,9 +285,8 @@ def R_σ(v, σ, constants, sizes, arrays): R_σ = jax.jit(R_σ, static_argnums=(3,)) ``` +Now we can define a function to compute $v_{\sigma}$ -Define a function to get the value $v_{\sigma}$ of policy -$\sigma$ by inverting the linear map $R_{\sigma}$. ```{code-cell} ipython3 def get_value(σ, constants, sizes, arrays): @@ -322,7 +307,7 @@ get_value = jax.jit(get_value, static_argnums=(2,)) ``` -Now we define the solvers, which implement VFI, HPI and OPI. +Finally, we introduce the solvers that implement VFI, HPI and OPI. ```{code-cell} ipython3 :load: _static/lecture_specific/vfi.py