Commit

misc
jstac committed Sep 21, 2023
1 parent 39cc01d commit 08ee04a
Showing 2 changed files with 38 additions and 38 deletions.
8 changes: 4 additions & 4 deletions lectures/opt_invest.md
@@ -154,7 +154,7 @@ B = jax.jit(B, static_argnums=(2,))
```

We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
which is defined as
which is defined as the vector

$$ r_\sigma(y, z) := r(y, z, \sigma(y, z)) $$
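As a rough illustration (not code from the lecture), one way to build this rewards array with JAX broadcasting is sketched below; the reward function `r`, the grids `y_grid` and `z_grid`, and the convention that `σ` stores indices into `y_grid` are all assumptions made for the example.

```python
import jax.numpy as jnp

def compute_r_σ_sketch(σ, r, y_grid, z_grid):
    """
    Illustrative sketch: build r_σ[i, j] = r(y[i], z[j], σ(y[i], z[j])).
    Assumes r is an elementwise (vectorised) reward function and that
    σ[i, j] stores the index in y_grid of the action chosen at (y[i], z[j]).
    """
    y = jnp.reshape(y_grid, (-1, 1))   # current capital, as a column
    z = jnp.reshape(z_grid, (1, -1))   # current shock, as a row
    y_next = y_grid[σ]                 # action values picked out by the policy
    return r(y, z, y_next)             # broadcasts to shape (len(y_grid), len(z_grid))
```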

@@ -241,13 +241,13 @@ Next, we want to compute the lifetime value of following policy $\sigma$.

This lifetime value is a function $v_\sigma$ that satisfies

$$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z) $$
$$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z') $$

for $v$.
We wish to solve this equation for $v_\sigma$.

Suppose we define the linear operator $L_\sigma$ by

$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z) $$
$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') $$

With this notation, the problem is to solve for $v$ via

68 changes: 34 additions & 34 deletions lectures/opt_savings.md
@@ -133,7 +133,14 @@ def B(v, constants, sizes, arrays):

## Operators

Now we define the policy operator $T_\sigma$

We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
which is defined as the vector


$$ r_\sigma(w, y) := r(w, y, \sigma(w, y)) $$



```{code-cell} ipython3
def compute_r_σ(σ, constants, sizes, arrays):
@@ -157,6 +164,8 @@ def compute_r_σ(σ, constants, sizes, arrays):
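    # r_σ[i, j] = r(w[i], y[j], σ(w[i], y[j])): one current reward per state pair (see the formula above)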
return r_σ
```

Now we define the policy operator $T_\sigma$

```{code-cell} ipython3
def T_σ(v, σ, constants, sizes, arrays):
"The σ-policy operator."
@@ -201,47 +210,36 @@ def get_greedy(v, constants, sizes, arrays):

The function below computes the value $v_\sigma$ of following policy $\sigma$.

The basic problem is to solve the linear system

$$ v(w,y ) = u(Rw + y - \sigma(w, y)) + β \sum_{y'} v(\sigma(w, y), y') Q(y, y) $$
This lifetime value is a function $v_\sigma$ that satisfies

for $v$.
$$ v_\sigma(w, y) = r_\sigma(w, y) + \beta \sum_{y'} v_\sigma(\sigma(w, y), y') Q(y, y') $$

It turns out to be helpful to rewrite this as
We wish to solve this equation for $v_\sigma$.

$$ v(w,y) = r(w, y, \sigma(w, y)) + β \sum_{w', y'} v(w', y') P_\sigma(w, y, w', y') $$
Suppose we define the linear operator $L_\sigma$ by

where $P_\sigma(w, y, w', y') = 1\{w' = \sigma(w, y)\} Q(y, y')$.
$$ (L_\sigma v)(w, y) = v(w, y) - \beta \sum_{y'} v(\sigma(w, y), y') Q(y, y') $$

We want to write this as $v = r_\sigma + P_\sigma v$ and then solve for $v$
With this notation, the problem is to solve for $v$ via

Note, however,
$$
(L_{\sigma} v)(w, y) = r_\sigma(w, y)
$$

* $v$ is a 2 index array, rather than a single vector.
* $P_\sigma$ has four indices rather than 2
In vector form, this is $L_\sigma v = r_\sigma$, which tells us that the function
we seek is

The code below
$$ v_\sigma = L_\sigma^{-1} r_\sigma $$

1. reshapes $v$ and $r_\sigma$ to 1D arrays and $P_\sigma$ to a matrix
2. solves the linear system
3. converts back to multi-index arrays.
JAX allows us to solve linear systems defined in terms of operators; the first
step is to define the function $L_{\sigma}$.

```{code-cell} ipython3
def R_σ(v, σ, constants, sizes, arrays):
def L_σ(v, σ, constants, sizes, arrays):
"""
The value v_σ of a policy σ is defined as
v_σ = (I - β P_σ)^{-1} r_σ
Here we set up the linear map v -> R_σ v, where R_σ := I - β P_σ.
Here we set up the linear map v -> L_σ v, where
In the consumption problem, this map can be expressed as
(R_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
Defining the map as above works in a more intuitive multi-index setting
(e.g. working with v[i, j] rather than flattening v to a one-dimensional
array) and avoids instantiating the large matrix P_σ.
(L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
"""
@@ -262,9 +260,11 @@ def R_σ(v, σ, constants, sizes, arrays):
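    # The weighted sum over axis 2 (the y′ axis) gives Σ_y′ v(σ(w, y), y′) Q(y, y′), as in the docstring formula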
return v - β * jnp.sum(V * Q, axis=2)
```

Now we can define a function to compute $v_{\sigma}$

```{code-cell} ipython3
def get_value(σ, constants, sizes, arrays):
"Get the value v_σ of policy σ by inverting the linear map R_σ."
"Get the value v_σ of policy σ by inverting the linear map L_σ."
# Unpack
β, R, γ = constants
@@ -273,10 +273,10 @@ def get_value(σ, constants, sizes, arrays):
    r_σ = compute_r_σ(σ, constants, sizes, arrays)
    # Reduce R_σ to a function in v
    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
    # Reduce L_σ to a function in v
    partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
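    # bicgstab accepts a callable linear operator and returns a (solution, info) pair, hence the [0] below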
    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
```
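The operator-based solve is the main trick in `get_value`, so a tiny standalone example may help make it concrete. Everything below (the operator `A`, the right-hand side `b`, the size `n`) is invented for illustration and is not part of the lecture code.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

n = 5

def A(v):
    # A linear map supplied as a plain callable, with no matrix ever built:
    # (A v) = v - 0.5 * (v shifted by one position)
    return v - 0.5 * jnp.roll(v, 1)

b = jnp.arange(1.0, n + 1.0)        # right-hand side of A v = b
v_star, info = bicgstab(A, b)       # matrix-free iterative solve
print(jnp.allclose(A(v_star), b, atol=1e-4))  # verify the solution
```

In `get_value` above, `partial_L_σ` plays the role of `A` and `r_σ` plays the role of `b`.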

## JIT compiled versions
@@ -288,7 +288,7 @@ T = jax.jit(T, static_argnums=(2,))
get_greedy = jax.jit(get_greedy, static_argnums=(2,))
get_value = jax.jit(get_value, static_argnums=(2,))
T_σ = jax.jit(T_σ, static_argnums=(3,))
R_σ = jax.jit(R_σ, static_argnums=(3,))
L_σ = jax.jit(L_σ, static_argnums=(3,))
```

## Solvers
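As a hedged sketch (not the lecture's own solver), one standard way to combine `get_value` and `get_greedy` is Howard policy iteration: alternate policy evaluation and greedy improvement until the policy stops changing. The initial value `v_init`, the iteration cap, and the loop structure below are illustrative assumptions.

```python
import jax.numpy as jnp

def policy_iteration_sketch(v_init, constants, sizes, arrays, max_iter=100):
    """
    Illustrative sketch of Howard policy iteration, reusing the get_greedy
    and get_value functions defined earlier in the lecture.
    """
    v = v_init
    σ = get_greedy(v, constants, sizes, arrays)
    for _ in range(max_iter):
        v = get_value(σ, constants, sizes, arrays)       # evaluate current policy
        σ_new = get_greedy(v, constants, sizes, arrays)  # greedy improvement step
        if jnp.array_equal(σ_new, σ):                    # stop once the policy is stable
            break
        σ = σ_new
    return σ, v
```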
