Sharding Step Functions
Every timestep we loop over the population and do conditional updates. We have put a lot of work into getting these to execute as fast as possible. And when we can use EULAs, we get the benefit of only having to operate on a fraction of the population. But in Cholera, for example, we don't have any EULAs, so we have to do operations on 1e8 agents. One optimization we've only just started to look at is updating less frequently than once every timestep, or time sub-sampling. This may be possible when dealing with operations that have longer timescales (weeks or months or years).
In a Cholera-like model, we can have tens of millions of `susceptibility_timer` values that need to be decremented each timestep. These timers are typically initialized to values measured in months or years. What if, instead of looping over an array sized 1e8 every timestep to decrement each non-zero value by 1, we did this update every 7 or even 30 timesteps, and decremented by 7 (or 30)?
That approach can work but it gets pretty "chunky". Is there a better way?
What if we "shard" the population (array) into a set of equal-sized blocks and process/update one block at a time? For example, if our update resolution is 7 days, we'll process 1/7th of the population (array) each timestep. After 7 timesteps we'll have processed the entire population. When we do process agents, we move them forward 7 timesteps at once. This gets us our 7x CPU benefit, but maintains a finer resolution as we have individuals reaching 0 every timestep.
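Here's a minimal sketch of the idea. The names (`timers`, `tick`, `delta`) are illustrative, and it uses plain numpy rather than our numba kernels:

```python
import numpy as np

def sharded_decrement(timers: np.ndarray, tick: int, delta: int = 7) -> None:
    """Advance one shard (1/delta of the array) forward by delta timesteps.

    Over delta consecutive ticks every agent is visited exactly once, but
    because a different shard is processed each tick, some agents still
    reach 0 on every tick.
    """
    block = len(timers) // delta
    start = (tick % delta) * block
    view = timers[start:start + block]      # this tick's contiguous shard
    np.maximum(view - delta, 0, out=view)   # decrement by delta, clamp at 0
```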
Here's the result of doing this in a standalone demo/example.
We have an array sized 1 million. We initialize it with integers uniformly distributed from 150 to 210. We decrement these each timestep. The plot shows the number of elements in the array equal to 0 at each timestep. Each timestep we reset 1% of the 0-valued elements to a random value in the original range.
We do this 3 ways:
- Default: decrement every non-zero element by 1, every timestep.
- Coarsely/naively: every 30th timestep we move everything forward by 30.
- Sharding: each timestep, decrement 1/7th of the array by 7.
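For concreteness, here's a sketch of the demo along those lines (a reconstruction, not the exact demo code; the parameters match the description above):

```python
import numpy as np

rng = np.random.default_rng(42)
N, T, DELTA = 1_000_000, 400, 7

def run(mode):
    # Integers uniformly distributed from 150 to 210, as described above.
    timers = rng.integers(150, 211, size=N, dtype=np.int32)
    zeros_per_tick = []
    for t in range(T):
        if mode == "default":
            timers[timers > 0] -= 1                  # everyone, every tick
        elif mode == "chunk" and t % 30 == 0:
            np.maximum(timers - 30, 0, out=timers)   # everyone, every 30 ticks
        elif mode == "shard":
            block = N // DELTA
            start = (t % DELTA) * block
            view = timers[start:start + block]       # 1/DELTA of the array
            np.maximum(view - DELTA, 0, out=view)    # forward DELTA at once
        zero_idx = np.flatnonzero(timers == 0)
        zeros_per_tick.append(zero_idx.size)         # the plotted quantity
        if zero_idx.size:
            # Reset 1% of the 0-valued elements to the original range.
            reset = rng.choice(zero_idx, size=zero_idx.size // 100, replace=False)
            timers[reset] = rng.integers(150, 211, size=reset.size, dtype=np.int32)
    return zeros_per_tick
```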
Total time spent in Default Decrement step: 0.4190 seconds
Total time spent in Sharded Decrement step: **0.0244 seconds**
Total time spent in Chunk Decrement step: 0.0154 seconds
Sharding gets us a big win on performance while losing almost nothing in our output.
We can use contiguous shards or "strided" shards. The former seems like a much better use of memory, but might not be randomized enough: if contiguous agents tend to be in the same node, then for any given node the effect of sharding will still be "chunky", I think. Also, numba needs a delta (stride) of 1, so numba would not be an option for strided shards. Some experimentation is needed.
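To make the two schemes concrete, here's a sketch of just the index selection (illustrative names; a real step function would pick one scheme or the other):

```python
import numpy as np

def contiguous_shard(timers: np.ndarray, tick: int, delta: int) -> np.ndarray:
    # One cache-friendly block per tick. Neighbors (often same-node agents)
    # all advance together, which is what makes per-node output "chunky".
    block = len(timers) // delta
    start = (tick % delta) * block
    return timers[start:start + block]

def strided_shard(timers: np.ndarray, tick: int, delta: int) -> np.ndarray:
    # Every delta-th agent, with a rotating offset. Each node's agents are
    # spread across all delta shards, so every node advances every tick,
    # at the cost of non-unit-stride (cache-unfriendly) access.
    return timers[tick % delta::delta]

# Both slices are views, so either can be advanced in place:
#   view = contiguous_shard(timers, tick, delta)
#   np.maximum(view - delta, 0, out=view)
```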
I explored the effects of applying this approach to one of the LASER step functions. I selected the immunity (susceptibility) countdown timer. The biggest time-users are transmission (~50%) and "ages" (~25%), but the immunity timer is still time-consuming and has the advantage of being simple. The ages update also does a census for reporting and for transmission, so it will be explored second.
We'll explore behavior and performance for both contiguous sharding and strided sharding. But first, let's establish the baseline.
The baseline: 35 seconds in do_susceptibility_decay.
Naive time sub-sampling:
- Delta=8
- Down to 3 seconds from 35 (for do_susceptibility_decay)!
Gets the right basic form:
But looking closer we see:
Contiguous sharding:
- Delta=8
- 4 seconds from 35!
Gets the right basic form again:
But looking closer we also see:
It seems that contiguous sharding is almost as fast as naive time sub-sampling but has some of the same "chunkiness". The reason is that the agents in our nodes tend to be contiguous, so in any given node the agents all move forward together by delta, and we don't gain much smoothness over the naive approach.
Strided sharding:
- Delta=8 (same)
- do_susceptibility_decay: 8.5 seconds (2x as slow as contiguous)
But now everything is smooth:
Zoomed in:
Strided sharding is a bit slower because it misses the benefit of zipping through purely contiguous data with an absolute minimum of cache misses, but it's still much faster than the original (8.5 s vs 35 s, a ~75% reduction) and maintains nice smooth behavior.
This plot shows the number of microseconds spent in the function as a function of "delta", which is the number of shards we use. As you can see, once we get to around 8 we're into diminishing returns.
Time slicing the transmission code requires a bit more math than the timer functions. This is because we are:
- Counting the number of infectious and susceptible populations (by node).
- Doing `delta` days of transmission at once (see the sketch below).
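As an illustration of that extra math, here's a hedged sketch: a strided shard, toy state codes, and a simple exponential force of infection; none of these names are LASER's actual API. The key point is that a daily infection probability has to be compounded over `delta` days, not scaled linearly:

```python
import numpy as np

rng = np.random.default_rng(0)

def sharded_transmission(state, node_id, tick, beta=0.3, delta=7, n_nodes=5):
    S, E, I = 0, 1, 2                     # illustrative state codes
    # Census: infectious and total population by node (needed every call).
    n_inf = np.bincount(node_id[state == I], minlength=n_nodes)
    n_pop = np.bincount(node_id, minlength=n_nodes)
    # Daily probability of infection from a toy force of infection...
    p_day = 1.0 - np.exp(-beta * n_inf / np.maximum(n_pop, 1))
    # ...compounded over delta days: 1 - (1 - p)**delta, not delta * p.
    p_delta = 1.0 - (1.0 - p_day) ** delta
    # Expose only this tick's (strided) shard of susceptibles.
    idx = np.arange(tick % delta, len(state), delta)
    sus = idx[state[idx] == S]
    hits = sus[rng.random(sus.size) < p_delta[node_id[sus]]]
    state[hits] = E
    return hits
```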
The plots below show both the SEIRW traces for the largest node and the prevalence plots for the 5 largest nodes, for deltas from 1 through 7.
(Unfortunately I've lost the labeling of the value of delta for each plot, but since the plots are highly similar, hopefully the point still stands until I can regenerate properly labelled plots. Also, one of the 7 values of delta isn't shown.)
The total sim duration varied from 1:19 with delta=7 to 1:35 with delta=1. So there's an opportunity to shave off about 16 seconds, or roughly 17% of wall-clock time.