diff --git a/lectures/inventory_dynamics.md b/lectures/inventory_dynamics.md index 96e1f7e6..0fb8d5ac 100644 --- a/lectures/inventory_dynamics.md +++ b/lectures/inventory_dynamics.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.5 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -29,7 +29,24 @@ kernelspec: ## Overview -This lecture explores JAX implementations of the exercises in the lecture on [inventory dynamics](https://python.quantecon.org/inventory_dynamics.html). +This lecture explores the inventory dynamics of a firm using so-called s-S inventory control. + +Loosely speaking, this means that the firm + +* waits until inventory falls below some value $s$ +* and then restocks with a bulk order of $S$ units (or, in some models, restocks up to level $S$). + +We will be interested in the distribution of the associated Markov process, +which can be thought of as cross-sectional distributions of inventory levels +across a large number of firms, all of which + +1. evolve independently and +1. have the same dynamics. + +Note that we also studied this model in a [separate +lecture](https://python.quantecon.org/inventory_dynamics.html), using Numba. + +Here we study the same problem using JAX. We will use the following imports: @@ -42,7 +59,7 @@ from jax import random, lax from collections import namedtuple ``` -Let's check the GPU we are running +Here's a description of our GPU: ```{code-cell} ipython3 !nvidia-smi @@ -54,7 +71,8 @@ Consider a firm with inventory $X_t$. The firm waits until $X_t \leq s$ and then restocks up to $S$ units. -It faces stochastic demand $\{ D_t \}$, which we assume is IID. +It faces stochastic demand $\{ D_t \}$, which we assume is IID across time and +firms. With notation $a^+ := \max\{a, 0\}$, inventory dynamics can be written as @@ -67,8 +85,6 @@ X_{t+1} = \end{cases} $$ -(See our earlier [lecture on inventory dynamics](https://python.quantecon.org/inventory_dynamics.html) for background and motivation.) - In what follows, we will assume that each $D_t$ is lognormal, so that $$ @@ -81,209 +97,292 @@ and standard normal. Here's a `namedtuple` that stores parameters. ```{code-cell} ipython3 -Firm = namedtuple('Firm', ['s', 'S', 'mu', 'sigma']) +Parameters = namedtuple('Parameters', ['s', 'S', 'μ', 'σ']) -firm = Firm(s=10, S=100, mu=1.0, sigma=0.5) +# Create a default instance +params = Parameters(s=10, S=100, μ=1.0, σ=0.5) ``` -## Example 1: marginal distributions -Now let’s look at the marginal distribution $\psi_T$ of $X_T$ for some fixed -$T$. -We can approximate the distribution using a [kernel density estimator](https://en.wikipedia.org/wiki/Kernel_density_estimation). -Kernel density estimators can be thought of as smoothed histograms. +## Cross-sectional distributions -We will use a kernel density estimator from [scikit-learn](https://scikit-learn.org/stable/). +Now let’s look at the marginal distribution $\psi_T$ of $X_T$ for some fixed $T$. -Here is an example of using kernel density estimators and plotting the result +The probability distribution $\psi_T$ is the time $T$ distribution of firm +inventory levels implied by the model. -```{code-cell} ipython3 -from sklearn.neighbors import KernelDensity +We will approximate this distribution by -def plot_kde(sample, ax, label=''): - xmin, xmax = 0.9 * min(sample), 1.1 * max(sample) - xgrid = np.linspace(xmin, xmax, 200) - kde = KernelDensity(kernel='gaussian').fit(sample[:, None]) - log_dens = kde.score_samples(xgrid[:, None]) +1. fixing $n$ to be some large number, indicating the number of firms in the + simulation, +1. fixing $T$, the time period we are interested in, +1. generating $n$ independent draws from some fixed distribution $\psi_0$ that gives the + initial cross-section of inventories for the $n$ firms, and +1. shifting this distribution forward in time $T$ periods, updating each firm + $T$ times via the dynamics described above (independent of other firms). - ax.plot(xgrid, np.exp(log_dens), label=label) +We will then visualize $\psi_T$ by histogramming the cross-section. -# Generate simulated data -np.random.seed(42) -sample_1 = np.random.normal(0, 2, size=10_000) -sample_2 = np.random.gamma(2, 2, size=10_000) +We will use the following code to update the cross-section of firms by one period. -# Create a plot -fig, ax = plt.subplots() +```{code-cell} ipython3 +@jax.jit +def update_cross_section(params, X_vec, D): + """ + Update by one period a cross-section of firms with inventory levels given by + X_vec, given the vector of demand shocks in D. -# Plot the samples -ax.hist(sample_1, alpha=0.2, density=True, bins=50) -ax.hist(sample_2, alpha=0.2, density=True, bins=50) - -# Plot the KDE for each sample -plot_kde(sample_1, ax, label=r'KDE over $X \sim N(0, 2)$') -plot_kde(sample_2, ax, label=r'KDE over $X \sim Gamma(0, 2)$') -ax.set_xlabel('value') -ax.set_ylabel('density') -ax.set_xlim([-5, 10]) -ax.set_title('KDE of Simulated Normal and Gamma Data') -ax.legend() -plt.show() -``` + * D[i] is the demand shock for firm i with current inventory X_vec[i] -This model for inventory dynamics is asymptotically stationary, with a unique -stationary distribution. + """ + # Unpack + s, S = params.s, params.S + # Restock if the inventory is below the threshold + X_new = jnp.where(X_vec <= s, + jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0)) + return X_new +``` -In particular, the sequence of marginal distributions $\{\psi_t\}$ -converges to a unique limiting distribution that does not depend on -initial conditions. -Although we will not prove this here, we can investigate it using simulation. -We can generate and plot the sequence $\{\psi_t\}$ at times -$t = 10, 50, 250, 500, 750$ based on the kernel density estimator. -We will see convergence, in the sense that differences between successive -distributions are getting smaller. +### For loop version -Here is one realization of the process in JAX using `for` loop +Now we provide code to compute the cross-sectional distribution $\psi_T$ given some +initial distribution $\psi_0$ and a positive integer $T$. -```{code-cell} ipython3 -# Define a jit-compiled function to update X and key -@jax.jit -def update_X(X, firm, D): - # Restock if the inventory is below the threshold - res = jnp.where(X <= firm.s, - jnp.maximum(firm.S - D, 0), - jnp.maximum(X - D, 0)) - return res +In this code we use an ordinary Python `for` loop to step forward through time +While Python loops are slow, this approach is reasonable here because +efficiency of outer loops has far less influence on runtime than efficiency of inner loops. -def shift_firms_forward(x_init, firm, sample_dates, - key, num_firms=50_000, sim_length=750): +(Below we will squeeze out more speed by compiling the outer loop as well as the +update rule.) - X = res = jnp.full((num_firms, ), x_init) +In the code below, the initial distribution $\psi_0$ takes all firms to have +initial inventory `x_init`. - # Use for loop to update X and collect samples - for i in range(sim_length): +```{code-cell} ipython3 +def compute_cross_section(params, x_init, T, key, num_firms=50_000): + # Set up initial distribution + X_vec = jnp.full((num_firms, ), x_init) + # Loop + for i in range(T): Z = random.normal(key, shape=(num_firms, )) - D = jnp.exp(firm.mu + firm.sigma * Z) + D = jnp.exp(params.μ + params.σ * Z) - X = update_X(X, firm, D) + X_vec = update_cross_section(params, X_vec, D) _, key = random.split(key) - # draw a sample at the sample dates - if (i+1 in sample_dates): - res = jnp.vstack((res, X)) - - return res[1:] + return X_vec ``` +We'll use the following specification + ```{code-cell} ipython3 x_init = 50 -num_firms = 50_000 -sample_dates = 10, 50, 250, 500, 750 +T = 500 +# Initialize random number generator key = random.PRNGKey(10) +``` -fig, ax = plt.subplots() +Let's look at the timing. -%time X = shift_firms_forward(x_init, firm, \ - sample_dates, key).block_until_ready() +```{code-cell} ipython3 +%time X_vec = compute_cross_section(params, \ + x_init, T, key).block_until_ready() +``` -for i, date in enumerate(sample_dates): - plot_kde(X[i, :], ax, label=f't = {date}') +```{code-cell} ipython3 +%time X_vec = compute_cross_section(params, \ + x_init, T, key).block_until_ready() +``` + +Here's a histogram of inventory levels at time $T$. +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.hist(X_vec, bins=50, + density=True, + histtype='step', + label=f'cross-section when $t = {T}$') ax.set_xlabel('inventory') ax.set_ylabel('probability') ax.legend() plt.show() ``` -Note that we did not JIT-compile the outer loop, since - -1. `jit` compilation of the `for` loop can be very time consuming and -1. compiling outer loops only leads to minor speed gains. - - -### Alternative implementation with `lax.scan` +### Compiling the outer loop -An alternative to the `for` loop implementation is `lax.scan`. +Now let's see if we can gain some speed by compiling the outer loop, which steps +through the time dimension. -Here is an example of the same function in `lax.scan` +We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready version of a `for` loop provided by JAX. ```{code-cell} ipython3 -@jax.jit -def shift_firms_forward(x_init, firm, key, - num_firms=50_000, sim_length=750): +def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000): - s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma + s, S, μ, σ = params.s, params.S, params.μ, params.σ X = jnp.full((num_firms, ), x_init) - Z = random.normal(key, shape=(sim_length, num_firms)) - D = jnp.exp(mu + sigma * Z) # Define the function for each update - def update_X(X, D): - res = jnp.where(X <= s, + def fori_update(t, inputs): + # Unpack + X, key = inputs + # Draw shocks using key + Z = random.normal(key, shape=(num_firms,)) + D = jnp.exp(μ + σ * Z) + # Update X + X = jnp.where(X <= s, jnp.maximum(S - D, 0), jnp.maximum(X - D, 0)) - return res, res + # Refresh the key + key, subkey = random.split(key) + return X, subkey - # Use lax.scan to perform the calculations on all states - _, X_final = lax.scan(update_X, X, D) + # Loop t from 0 to T, applying fori_update each time. + # The initial condition for fori_update is (X, key). + X, key = lax.fori_loop(0, T, fori_update, (X, key)) - return X_final + return X + +# Compile taking T and num_firms as static (changes trigger recompile) +compute_cross_section_fori = jax.jit( + compute_cross_section_fori, static_argnums=(2, 4)) ``` -The benefit of the `lax.scan` implementation is that we compile the whole -operation. +Let's see how fast this runs with compile time. + +```{code-cell} ipython3 +%time X_vec = compute_cross_section_fori(params, \ + x_init, T, key).block_until_ready() +``` + +And let's see how fast it runs without compile time. + +```{code-cell} ipython3 +%time X_vec = compute_cross_section_fori(params, \ + x_init, T, key).block_until_ready() +``` + +Compared to the original version with a pure Python outer loop, we have +produced a nontrivial speed gain. + + +This is due to the fact that we have compiled the whole operation. -The disadvantages are that -1. as mentioned above, there are only limited speed gains in accelerating outer loops, -2. `lax.scan` has a more complicated syntax, and, most importantly, -3. the `lax.scan` implementation consumes far more memory, as we need to have to - store large matrices of random draws -Let's call the code to generate a cross-section that is in approximate -equilibrium. +### Further vectorization + +For relatively small problems, we can make this code run even faster by generating +all random variables at once. + +This improves efficiency because we are taking more operations out of the loop. ```{code-cell} ipython3 -fig, ax = plt.subplots() +def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000): -%time X = shift_firms_forward(x_init, firm, key).block_until_ready() + s, S, μ, σ = params.s, params.S, params.μ, params.σ + X = jnp.full((num_firms, ), x_init) + Z = random.normal(key, shape=(T, num_firms)) + D = jnp.exp(μ + σ * Z) -for date in sample_dates: - plot_kde(X[date, :], ax, label=f't = {date}') + def update_cross_section(i, X): + X = jnp.where(X <= s, + jnp.maximum(S - D[i, :], 0), + jnp.maximum(X - D[i, :], 0)) + return X -ax.set_xlabel('inventory') -ax.set_ylabel('probability') -ax.legend() -plt.show() + X = lax.fori_loop(0, T, update_cross_section, X) + + return X + +# Compile taking T and num_firms as static (changes trigger recompile) +compute_cross_section_fori = jax.jit( + compute_cross_section_fori, static_argnums=(2, 4)) +``` + +Let's test it with compile time included. + +```{code-cell} ipython3 +%time X_vec = compute_cross_section_fori(params, \ + x_init, T, key).block_until_ready() +``` + +Let's run again to eliminate compile time. + +```{code-cell} ipython3 +%time X_vec = compute_cross_section_fori(params, \ + x_init, T, key).block_until_ready() ``` -Notice that by $t=500$ or $t=750$ the densities are barely -changing. +On one hand, this version is faster than the previous one, where random variables were +generated inside the loop. + +On the other hand, this implementation consumes far more memory, as we need to +store large arrays of random draws. + +The high memory consumption becomes problematic for large problems. + + -We have reached a reasonable approximation of the stationary density. +## Distribution dynamics -You can test a few more initial conditions to show that they do not affect -long-run outcomes. +Next let's take a look at how the distribution sequence evolves over time. -For example, try rerunning the code above with all firms starting at -$X_0 = 20$ +We will go back to using ordinary Python `for` loops. + +Here is code that repeatedly shifts the cross-section forward while +recording the cross-section at the dates in `sample_dates`. ```{code-cell} ipython3 -x_init = 20.0 +def shift_forward_and_sample(x_init, params, sample_dates, + key, num_firms=50_000, sim_length=750): -fig, ax = plt.subplots() + X = res = jnp.full((num_firms, ), x_init) -%time X = shift_firms_forward(x_init, firm, key).block_until_ready() + # Use for loop to update X and collect samples + for i in range(sim_length): + Z = random.normal(key, shape=(num_firms, )) + D = jnp.exp(params.μ + params.σ * Z) + + X = update_cross_section(params, X, D) + _, key = random.split(key) + + # draw a sample at the sample dates + if (i+1 in sample_dates): + res = jnp.vstack((res, X)) + + return res[1:] +``` + +Let's test it + +```{code-cell} ipython3 +x_init = 50 +num_firms = 10_000 +sample_dates = 10, 50, 250, 500, 750 +key = random.PRNGKey(10) -for date in sample_dates: - plot_kde(X[date, :], ax, label=f't = {date}') + +%time X = shift_forward_and_sample(x_init, params, \ + sample_dates, key).block_until_ready() +``` + +Let's plot the output. + +```{code-cell} ipython3 +fig, ax = plt.subplots() + +for i, date in enumerate(sample_dates): + ax.hist(X[i, :], bins=50, + density=True, + histtype='step', + label=f'cross-section when $t = {date}$') ax.set_xlabel('inventory') ax.set_ylabel('probability') @@ -291,32 +390,53 @@ ax.legend() plt.show() ``` -## Example 2: restock frequency +This model for inventory dynamics is asymptotically stationary, with a unique +stationary distribution. + +In particular, the sequence of marginal distributions $\{\psi_t\}$ +converges to a unique limiting distribution that does not depend on +initial conditions. + +Although we will not prove this here, we can see it in the simulation above. + +By $t=500$ or $t=750$ the distributions are barely changing. + +If you test a few different initial conditions, you will see that they do not affect long-run outcomes. -Let's go through another example where we calculate the probability of firms -having restocks. -Specifically we set the starting stock level to 70 ($X_0 = 70$), as we calculate -the proportion of firms that need to order twice or more in the first 50 -periods. -You will need a large sample size to get an accurate reading. -Again, we start with an easier `for` loop implementation + +## Restock frequency + +As an exercise, let's study the probability that firms need to restock over a given time period. + +In the exercise, we will + +* set the starting stock level to $X_0 = 70$ and +* calculate the proportion of firms that need to order twice or more in the first 50 periods. + +This proportion approximates the probability of the event when the sample size +is large. + + +### For loop version + +We start with an easier `for` loop implementation ```{code-cell} ipython3 # Define a jitted function for each update @jax.jit -def update_stock(n_restock, X, firm, D): - n_restock = jnp.where(X <= firm.s, +def update_stock(n_restock, X, params, D): + n_restock = jnp.where(X <= params.s, n_restock + 1, n_restock) - X = jnp.where(X <= firm.s, - jnp.maximum(firm.S - D, 0), + X = jnp.where(X <= params.s, + jnp.maximum(params.S - D, 0), jnp.maximum(X - D, 0)) return n_restock, X, key -def compute_freq(firm, key, +def compute_freq(params, key, x_init=70, sim_length=50, num_firms=1_000_000): @@ -330,9 +450,9 @@ def compute_freq(firm, key, # Use a for loop to perform the calculations on all states for i in range(sim_length): Z = random.normal(key, shape=(num_firms, )) - D = jnp.exp(firm.mu + firm.sigma * Z) + D = jnp.exp(params.μ + params.σ * Z) n_restock, X, key = update_stock( - n_restock, X, firm, D) + n_restock, X, params, D) key = random.fold_in(key, i) return jnp.mean(n_restock > 1, axis=0) @@ -340,33 +460,45 @@ def compute_freq(firm, key, ```{code-cell} ipython3 key = random.PRNGKey(27) -%time freq = compute_freq(firm, key).block_until_ready() +%time freq = compute_freq(params, key).block_until_ready() print(f"Frequency of at least two stock outs = {freq}") ``` -### Alternative implementation with `lax.fori_loop` +```{exercise-start} +:label: inventory_dynamics_ex1 +``` -Now let's write a `lax.fori_loop` version that JIT compiles the whole function +Write a `fori_loop` version of the last function. See if you can increase the +speed while generating a similar answer. + +```{exercise-end} +``` + +```{solution-start} inventory_dynamics_ex1 +:class: dropdown +``` + +Here is a `lax.fori_loop` version that JIT compiles the whole function ```{code-cell} ipython3 @jax.jit -def compute_freq(firm, key, +def compute_freq(params, key, x_init=70, sim_length=50, num_firms=1_000_000): - s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma + s, S, μ, σ = params.s, params.S, params.μ, params.σ # Prepare initial arrays X = jnp.full((num_firms, ), x_init) Z = random.normal(key, shape=(sim_length, num_firms)) - D = jnp.exp(mu + sigma * Z) + D = jnp.exp(μ + σ * Z) # Stack the restock counter on top of the inventory restock_count = jnp.zeros((num_firms, )) Xs = (X, restock_count) # Define the function for each update - def update_X(i, Xs): + def update_cross_section(i, Xs): # Separate the inventory and restock counter x, restock_count = Xs[0], Xs[1] restock_count = jnp.where(x <= s, @@ -380,7 +512,7 @@ def compute_freq(firm, key, return Xs # Use lax.fori_loop to perform the calculations on all states - X_final = lax.fori_loop(0, sim_length, update_X, Xs) + X_final = lax.fori_loop(0, sim_length, update_cross_section, Xs) return jnp.mean(X_final[1] > 1) ``` @@ -388,6 +520,11 @@ def compute_freq(firm, key, Note the time the routine takes to run, as well as the output ```{code-cell} ipython3 -%time freq = compute_freq(firm, key).block_until_ready() +%time freq = compute_freq(params, key).block_until_ready() +%time freq = compute_freq(params, key).block_until_ready() + print(f"Frequency of at least two stock outs = {freq}") ``` + +```{solution-end} +```