From 39cc01da64ecbaa62c949cd98757c3d17fc09b3f Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Thu, 21 Sep 2023 17:53:48 -0500 Subject: [PATCH] misc --- lectures/opt_invest.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/lectures/opt_invest.md b/lectures/opt_invest.md index d7220140..a2fc389c 100644 --- a/lectures/opt_invest.md +++ b/lectures/opt_invest.md @@ -245,26 +245,26 @@ $$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') for $v$. -Suppose we define the linear operator $R_\sigma$ by +Suppose we define the linear operator $L_\sigma$ by -$$ (R_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$ +$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$ With this notation, the problem is to solve for $v$ via $$ - (R_{\sigma} v)(y, z) = r_\sigma(y, z) + (L_{\sigma} v)(y, z) = r_\sigma(y, z) $$ -In vector for this is $R_\sigma v = r_\sigma$, which tells us that the function +In vector for this is $L_\sigma v = r_\sigma$, which tells us that the function we seek is -$$ v_\sigma = R_\sigma^{-1} r_\sigma $$ +$$ v_\sigma = L_\sigma^{-1} r_\sigma $$ JAX allows us to solve linear systems defined in terms of operators; the first -step is to define the function $R_{\sigma}$. +step is to define the function $L_{\sigma}$. ```{code-cell} ipython3 -def R_σ(v, σ, constants, sizes, arrays): +def L_σ(v, σ, constants, sizes, arrays): β, a_0, a_1, γ, c = constants y_size, z_size = sizes @@ -282,7 +282,7 @@ def R_σ(v, σ, constants, sizes, arrays): # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp] return v - β * jnp.sum(V * Q, axis=2) -R_σ = jax.jit(R_σ, static_argnums=(3,)) +L_σ = jax.jit(L_σ, static_argnums=(3,)) ``` Now we can define a function to compute $v_{\sigma}$ @@ -298,10 +298,10 @@ def get_value(σ, constants, sizes, arrays): r_σ = compute_r_σ(σ, constants, sizes, arrays) - # Reduce R_σ to a function in v - partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays) + # Reduce L_σ to a function in v + partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays) - return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0] + return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0] get_value = jax.jit(get_value, static_argnums=(2,)) ```