Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
jstac committed Sep 21, 2023
1 parent 317e12a commit 39cc01d
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions lectures/opt_invest.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,26 +245,26 @@ $$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z')

for $v$.

Suppose we define the linear operator $R_\sigma$ by
Suppose we define the linear operator $L_\sigma$ by

$$ (R_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$
$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$

With this notation, the problem is to solve for $v$ via

$$
(R_{\sigma} v)(y, z) = r_\sigma(y, z)
(L_{\sigma} v)(y, z) = r_\sigma(y, z)
$$

In vector for this is $R_\sigma v = r_\sigma$, which tells us that the function
In vector for this is $L_\sigma v = r_\sigma$, which tells us that the function
we seek is

$$ v_\sigma = R_\sigma^{-1} r_\sigma $$
$$ v_\sigma = L_\sigma^{-1} r_\sigma $$

JAX allows us to solve linear systems defined in terms of operators; the first
step is to define the function $R_{\sigma}$.
step is to define the function $L_{\sigma}$.

```{code-cell} ipython3
def R_σ(v, σ, constants, sizes, arrays):
def L_σ(v, σ, constants, sizes, arrays):
β, a_0, a_1, γ, c = constants
y_size, z_size = sizes
Expand All @@ -282,7 +282,7 @@ def R_σ(v, σ, constants, sizes, arrays):
# Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
return v - β * jnp.sum(V * Q, axis=2)
R_σ = jax.jit(R_σ, static_argnums=(3,))
L_σ = jax.jit(L_σ, static_argnums=(3,))
```

Now we can define a function to compute $v_{\sigma}$
Expand All @@ -298,10 +298,10 @@ def get_value(σ, constants, sizes, arrays):
r_σ = compute_r_σ(σ, constants, sizes, arrays)
# Reduce R_σ to a function in v
partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
# Reduce L_σ to a function in v
partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
get_value = jax.jit(get_value, static_argnums=(2,))
```
Expand Down

0 comments on commit 39cc01d

Please sign in to comment.