
Sharding Step Functions


Every timestep we loop over the population and do conditional updates. We have put a lot of work into making these execute as fast as possible, and when we can use EULAs we get the benefit of only having to operate on a fraction of the population. But in Cholera, for example, we don't have any EULAs, so we have to do operations on 1e8 agents. One optimization we've only just started to look at is updating less frequently than once every timestep, or time sub-sampling. This may be possible for operations that have longer timescales (weeks, months, or years).

In a Cholera-like model, we can have tens of millions of susceptibility_timer values which need to be decremented each timestep. These timers are typically initialized to values measured in months or years. What if instead of looping over an array sized 1e8 every timestep to decrement each non-zero value by 1, we did this update every 7 or even 30 timesteps, and then decremented by 7 (or 30)?
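A rough sketch of that coarse update, assuming the timers live in a numpy array (the function name is mine, not LASER's):

```python
import numpy as np

def coarse_decrement(susceptibility_timer, tick, delta=7):
    # Only touch the array every `delta` ticks, and decrement by `delta`
    # to cover the skipped timesteps in one pass, clamping at zero.
    if tick % delta == 0:
        active = susceptibility_timer > 0
        susceptibility_timer[active] = np.maximum(
            susceptibility_timer[active] - delta, 0
        )
```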

That approach can work but it gets pretty "chunky". Is there a better way?

What if we "shard" the population (array) into a set of equal-sized blocks and process/update one block at a time? For example, if our update resolution is 7 days, we'll process 1/7th of the population (array) each timestep. After 7 timesteps we'll have processed the entire population. When we do process agents, we move them forward 7 timesteps at once. This gets us our 7x CPU benefit, but maintains a finer resolution as we have individuals reaching 0 every timestep.

Simplified Demo

Here's the result of doing this in a standalone demo/example.

image

We have an array of size 1 million. We initialize it with integers uniformly distributed from 150 to 210 and decrement them each timestep. The plot shows the number of elements in the array equal to 0 at each timestep. Each timestep we also reset 1% of the 0-valued elements to a random value in the original range.
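For concreteness, here is a minimal sketch of that harness with the default (per-tick) update; the other two variants just swap out the decrement step (sizes and rates as described above, everything else is illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
N, LOW, HIGH = 1_000_000, 150, 210
timers = rng.integers(LOW, HIGH + 1, size=N).astype(np.int32)

zero_counts = []
for tick in range(1000):
    # Default: decrement every non-zero timer by 1, every tick.
    timers[timers > 0] -= 1

    # Reset 1% of the zero-valued elements to a fresh random value.
    zeros = np.flatnonzero(timers == 0)
    if zeros.size:
        reset = rng.choice(zeros, size=zeros.size // 100, replace=False)
        timers[reset] = rng.integers(LOW, HIGH + 1, size=reset.size)

    zero_counts.append(int(np.count_nonzero(timers == 0)))
```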

We do this 3 ways:

  • Default
  • Coarsely/naively: every 30th timestep we move everything forward by 30.
  • Sharding
Total time spent in Default Decrement step: 0.4190 seconds
Total time spent in Sharded Decrement step: 0.0244 seconds
Total time spent in Chunk Decrement step: 0.0154 seconds

Sharding gets us a big win on performance while losing almost nothing in our output.

2 Ways to Shard?

We can use contiguous shards or "strided" shards. The former seems like a much better use of memory (and cache), but could potentially not be randomized enough: if contiguous agents tend to be in the same node, then for any given node the effect of sharding will still be "chunky", I think. Also, numba needs a loop step of 1, so numba would not be an option for the strided version. Some experimentation is needed.
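The two schemes differ only in how each tick's shard is selected from the flat array; a quick sketch (illustrative helper names):

```python
def contiguous_shard(timers, tick, delta):
    # Shard k is one contiguous block: cache-friendly, unit stride,
    # but neighbouring agents (often the same node) update together.
    n, k = timers.size, tick % delta
    return timers[k * n // delta:(k + 1) * n // delta]

def strided_shard(timers, tick, delta):
    # Shard k is every delta-th element starting at offset k: the shard
    # is spread across nodes, at the cost of a non-unit stride.
    return timers[tick % delta::delta]
```

Both return views, so updates write straight back into the original array.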

LASER

Immunity Timer

I explored the effects of applying this approach to one of the LASER step functions. I selected the immunity (susceptibility) countdown timer. The biggest time-users are transmission (~50%) and "ages" (~25%), but the immunity timer is still time-consuming and has the advantage of being simple. The ages update also does a census for reporting and for transmission, so that will be explored second.
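Conceptually the per-tick update is just a masked countdown over the whole population; a simplified stand-in for do_susceptibility_decay (not the actual LASER code) looks something like this:

```python
import numba as nb

@nb.njit(parallel=True)
def do_susceptibility_decay(immunity_timer, susceptibility):
    # Every tick: count down each active timer; when it reaches zero,
    # the agent becomes susceptible again.
    for i in nb.prange(immunity_timer.size):
        if immunity_timer[i] > 0:
            immunity_timer[i] -= 1
            if immunity_timer[i] == 0:
                susceptibility[i] = 1
```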

We'll explore behavior and performance for both contiguous sharding and strided sharding. But first, let's establish the baseline.

Baseline

image

image

(35 seconds)

Naive Time Subsampling

  • Delta=8
  • Down to 3 seconds from 35 (for do_susceptibility_decay)!

image

Gets the right basic form: image

But looking closer we see: image

Contiguous Sharding

  • Delta=8
  • Down to 4 seconds from 35!

image

Gets the right basic form again: image

But looking closer we also see: image

Contiguous sharding is about as fast as naive time sub-sampling, but it shows some of the same "chunkiness". The reason is that agents in the same node tend to be contiguous in the array, so within any given node the agents all move forward together by delta, and we don't gain much smoothness.

Strided Sharding

  • Delta=8 (same)
  • do_susceptibility_decay=8.5s (2x as slow as contiguous)
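Since the existing numba kernel assumes unit stride, the strided variant below is sketched with plain numpy (an illustration rather than the production code):

```python
import numpy as np

def do_susceptibility_decay_strided(immunity_timer, susceptibility, tick, delta=8):
    # Visit only the agents in shard `tick % delta`, advancing them by delta days.
    shard_timers = immunity_timer[tick % delta::delta]   # strided view
    shard_susc = susceptibility[tick % delta::delta]     # matching view
    active = shard_timers > 0
    shard_timers[active] = np.maximum(shard_timers[active] - delta, 0)
    # Agents whose timer just expired become susceptible again.
    shard_susc[active & (shard_timers == 0)] = 1
```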

image

But now everything is smooth:

image

Zoomed in:

image

Observations

Strided sharding is a bit slower because it misses the benefit of zipping through purely contiguous data with an absolute minimum of cache misses, but it's still much faster than the original (8.5 s vs. 35 s, roughly a 75% reduction) and maintains nice smooth behavior.

Optimal Delta?

image

This plot shows the number of microseconds spent in the function as a function of "delta", which is the number of shards we use. As you can see, once we get to around 8 we're in diminishing returns.
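That sweep is easy to reproduce; a rough timing harness for the strided update (my own scaffolding, reusing the 1-million-element setup from the demo):

```python
import time
import numpy as np

def time_strided_update(timers, delta, ticks=100):
    # Average microseconds per tick spent in the sharded decrement.
    start = time.perf_counter()
    for tick in range(ticks):
        block = timers[tick % delta::delta]
        active = block > 0
        block[active] = np.maximum(block[active] - delta, 0)
    return (time.perf_counter() - start) / ticks * 1e6

base = np.random.default_rng(0).integers(150, 211, size=1_000_000)
for delta in (1, 2, 4, 8, 16, 32):
    print(delta, round(time_strided_update(base.copy(), delta), 1), "µs/tick")
```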

Transmission

Time slicing the transmission code requires a bit more math than the timer functions. This is because we are:

  • Counting the number of infectious and susceptible populations (by node).
  • Doing delta days of transmission at once (sketched below).
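The extra math is mostly about compounding the per-day infection probability over delta days instead of applying it once per day. A hedged sketch of that piece, assuming a simple well-mixed force of infection per node (this is my illustration, not the LASER transmission kernel):

```python
import numpy as np

def new_infections_over_delta(n_susceptible, n_infectious, population, beta, delta):
    # Per-day force of infection and per-day infection probability.
    lam = beta * n_infectious / population
    p_day = 1.0 - np.exp(-lam)
    # Probability of being infected at least once over `delta` days.
    p_delta = 1.0 - (1.0 - p_day) ** delta
    # Draw the number of new infections for the whole block of days.
    return np.random.binomial(n_susceptible, p_delta)
```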

The plots below show both the SEIRW traces for the largest node and also the prevalence plots for the 5 largest nodes, for deltas from 1 through 7.

(Unfortunately I've lost the labeling of the delta value for each plot, but since the plots are highly similar, hopefully the point still stands until I can regenerate properly labelled plots. Also, one of the 7 delta values isn't shown here.)

The total sim duration varied from 1:19 with delta=7 to 1:35 with delta=1. So there's an opportunity to shave off about 16 seconds, or roughly 17% of wall-clock time.

Prevalence Plots across Largest Nodes

laser_transmission_timesl-000 laser_transmission_timesl-004 laser_transmission_timesl-012 laser_transmission_timesl-016 laser_transmission_timesl-020 laser_transmission_timesl-024

SEIRW Plots for Largest Node

laser_transmission_timesl-026 laser_transmission_timesl-002 laser_transmission_timesl-006 laser_transmission_timesl-010 laser_transmission_timesl-018 laser_transmission_timesl-022
