Using coiled/dask to parallelize making/storing task graphs #229
Comments
@gjoseph92 please take a look when you're back on Monday.
Dask generally isn't designed to build task graphs in parallel. Though it's possible to get this to work, I think it might indicate a broader problem, and you should take a step back and look at different approaches. If your graphs are so large that you want a cluster just to build them, I think it's unlikely you'd be able to get those graphs to actually run effectively once they're built. If graph construction is a bottleneck, I'd instead look at how you're using dask and think about ways to avoid building such large graphs, or so many graphs. Or, it's possible that what you're doing is reasonable, but you're encountering bugs in dask or xarray that make things much, much slower than they should be.

Profiling the code that constructs the graphs will be a good place to start. I recommend using py-spy and the speedscope output format. We'll want to see where the bottlenecks are, and why building the graphs is taking so long. After that, an overview / pseudocode of what you're doing, and broadly what you're trying to accomplish, would help too.
Generally, my problems come from trying to parallelize embarrassingly parallel parts of my data-intake pipeline. For example: generating a stackstac xarray object for a point in space from a series of GeoTIFFs. Each point is independent. I search the STAC API for the images, feed their metadata into stackstac, and get an xarray object backed by a dask graph. For any individual point the graph is not unmanageably large or slow to generate (I'll measure this), but for 100+ points it becomes a client-side bottleneck. Naturally I would parallelize this. Perhaps using dask to achieve that is not the best approach?

Thanks for the profiling recommendations - I will take a look and get back to you.

Slightly off-topic: does stackstac read asynchronously? Later in the pipeline those reads are a bottleneck.
So from a series of GeoTIFFs, you want to extract data for a single pixel? Why multiple GeoTIFFs, then? A point will belong to at most one GeoTIFF, so using multiple seems unnecessary. Unless you're talking about wanting to get a timeseries for that point ("pixel skewer")? If you are looking to get pixel timeseries for many points, I'd want to know about the spatial distribution of the points (are they often near each other / do multiple fall within the same GeoTIFF), and whether every point uses a different set of GeoTIFFs, or always the same set.
How are you determining that this is the bottleneck? (Versus searching the STAC API, etc.) Again, profiling will be helpful here.
Not sure what you mean there. stackstac creates dask tasks to do the reads, which are executed in parallel by dask when you later call `.compute()`. Regardless, IO is definitely going to be the bottleneck in any GeoTIFF-based workflow.
I'm pulling a time series from multiple images. The spatial distribution is far apart (non-overlapping GeoTIFFs, in general), so each point uses a different set of GeoTIFFs.

I determined that this is the bottleneck from prefect logs, as well as container CPU/memory tracking. Searching the STAC API is something I parallelize on the cluster, and it completes in a few seconds, whereas my logs show that the stack() operation (looped) takes hours. After selecting the point time series, I join them together to get one final graph. The actual execution of the reads and compute on this graph takes <30 min. (I'll provide a more detailed profile later.)

What I mean by asynchronous is: since I'm extracting a time series, each dask chunk includes multiple images (to avoid even bigger graphs). Within a task/chunk I'm reading a small piece of each of several GeoTIFFs, and those reads could either block one after another or execute asynchronously. So a single dask task could be a lot faster if it did all the reads inside it asynchronously. (Analogously, if each dask task only read one image - a very small chunk of it - you'd need loads of threads or something asynchronous.) FYI, these are cloud-optimized GeoTIFFs. Agreed that I/O is my primary concern - I'm just comparing this part of the pipeline to reads from zarr, for example (where the zarr chunking is similar to the GeoTIFF chunking).
Quick addendum: the times mentioned are on an ECS agent, and differ from local docker tests. That probably points to another cause for those super-long times; I'll try to track that down in the meantime and profile both. More consistent is the memory usage, which is another reason I was interested in putting those objects on the cluster (when the client is small).
That is an incredibly long time. Again, a profile here would be extremely helpful. I still wouldn't be surprised if dask/dask#8570 was also part of the issue - unless you're explicitly setting the bounds to a small area around each point?
When multiple assets are in a single chunk, stackstac reads them serially: https://github.com/gjoseph92/stackstac/blob/d53f3dc94e49b95b950201f9e53539a92b1458b6/stackstac/to_dask.py#L180-L182 Chunks are still processed in parallel, though, so you'll still have roughly as many reads in flight as you have dask worker threads.
For this sort of operation (get timeseries of 100,000 points), it should be as simple as:

```python
all_items = pystac_client.search(...)
huge_stack = stackstac.stack(all_items, ...)
xs = xr.DataArray(lons, dims=['point'])
ys = xr.DataArray(lats, dims=['point'])
ts = huge_stack.sel(x=xs, y=ys, method="nearest")
```

Basically you make a massive array of all the data covering the entire world, then index out just the parts you need. Theoretically this should be fast, but because of dask/dask#8570, it's so slow that it's not possible.

So instead, sadly, I'd probably do what it sounds like you're trying to do, and what is generally bad practice in dask: nesting dask operations. I'd have dask tasks that, internally, generate dask objects, and probably even compute them. This is usually a bad idea, and easy to mess up, but I don't think we have a better option.
Here's a semi-working example of generating the stacks on the cluster by using a dask Bag of points, and mapping a function over each point that loads data with stackstac. However, I don't think this will work—you'll also need to compute each point's timeseries on the cluster, within that function, for things to work at scale. Unfortunately, that makes dealing with metadata much less pleasant.

```python
import pystac_client
import pystac
import stackstac
import xarray as xr
import numpy as np
import planetary_computer as pc
import shapely.geometry
import pyproj
import dask
import dask.bag as db
import coiled

rng = np.random.default_rng()
bounds_wgs84 = -121, 31, -66, 48
points = rng.integers(bounds_wgs84[:2], bounds_wgs84[2:], (100, 2))
labels = np.arange(len(points))

stac_client = pystac_client.Client.open("https://planetarycomputer.microsoft.com/api/stac/v1/")


def timeseries_for_point(label: int, lon: float, lat: float) -> xr.DataArray:
    point = shapely.geometry.Point(lon, lat)
    items = pc.sign(
        stac_client.search(
            collections=["sentinel-2-l2a"],
            intersects=point,
        )
    )
    if len(items) == 0:
        return None
    stack = stackstac.stack(
        items,
        assets=["B01", "B02", "B03"],
        bounds_latlon=point.buffer(0.0001).envelope.bounds,
        epsg=3857,  # FIXME obviously don't use 3857 IRL
    )
    x, y = pyproj.Transformer.from_crs("epsg:4326", stack.attrs['crs'], always_xy=True).transform(lon, lat)
    ts = stack.sel(x=x, y=y, method="nearest").assign_coords(label=label)
    # Returning a dask object is not great here.
    # Would prefer to return the actual data: `ts.compute(scheduler='threads')`.
    # However, by returning the dask objects, you can get the metadata back
    # to the client, without the full data.
    # But, if you try to then combine/concatenate all these dask
    # objects on the client in the end, it'll probably blow up at
    # large numbers. Even if you do get them concatenated, optimizing
    # the graph to compute them will blow up.
    # So computing in-place here is probably the only thing that'll actually
    # work at scale.
    return ts


def resample_ts(ts: xr.DataArray) -> xr.DataArray:
    # resampling to a common time frequency makes sense,
    # otherwise all timeseries will be different lengths
    return ts.resample(time="1M").mean()


bag = db.from_sequence(
    np.concatenate([labels[:, None], points], axis=-1),
    partition_size=1,
)
tsb = (
    bag
    .starmap(timeseries_for_point)
    .filter(lambda ts: ts is not None)
    .map(resample_ts)
    .map(dask.optimize)  # reduce number of graph layers
    .pluck(0)  # because optimize always returns a tuple
)

cluster = coiled.Cluster(n_workers=10, package_sync=True)
client = cluster.get_client()

tss = tsb.compute()
# TODO want `coords='different'`, but mismatches are annoying
full_ts = xr.concat(tss, "label", coords='minimal', compat='override')
```
Thank you for the recommendations and the code. I'll try it out. In the meanwhile I've been trying to track down what's making it so slow. I did several profiles; I've attached one. (Quick note: the instructions you sent produce a link which does not work on speedscope, so I directly linked the gist.)

On the other hand, when I run the exact same code on the same container without running py-spy to profile it, it takes 17 min to execute the stacking. With the same docker container and code, when I run the container and Prefect flow locally, it doesn't exhibit this difference between profiling and not-profiling; it executes in ~50 sec. In addition, this difference in performance does not occur for any other part of the pipeline, including the actual execution of the dask graph - only for the stacking step. Do you have any suggestions as to what could be happening?

The performance of a few hours I mentioned earlier was for 60-100 spatial points (again, 500 images average per point), running on ECS Fargate. Running locally in Docker it's <5 min. Again, no other part of the pipeline shows such a difference. Note that I am explicitly setting the bounds very small.
On another note: the additional memory burden of threads within a task would be, in the worst case, the number of simultaneous reads times the extra data (to be thrown away) per read. Since I'm reading from a different image in each thread, I presume I don't need to copy any GDALDatasets. In my case the size of each chunk (50 images) is tiny (1 MB), whereas each read task takes about 50 sec. If we assume the extra data per read is on the same order of magnitude as the data I keep, each task would use 50 MB max. So threads within the task would get me ~1 sec per chunk, for ~50 MB per chunk: a 50x speedup without significantly changing my memory usage.

The other option is to keep each chunk/task single-threaded and put lots of threads on dask. This would require a graph with one image per chunk (so 50x the graph size) to get the same performance with 50 threads per worker, and the worker still uses 50x the memory. However, that makes everything on my cluster use 50 threads per worker, instead of just this specific operation, where I know the execution time is very large in comparison to the memory usage. I've experimented with the second option, and it causes memory pressure since other operations are also executing on the cluster (e.g. zarr reads, rechunking, resampling). I get out-of-memory errors even with nthreads = 2 * ncores instead of nthreads = ncores. It also makes the task graph very large (>300K tasks).

A similar logic applies if I want a very large number of disjoint spatial points instead of a time series. Please let me know if I've got anything wrong here, or if there's another, better option.
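To make the first option concrete, here's a rough sketch of what I mean by "threads within a task". `read_window` is a hypothetical stand-in for stackstac's per-asset read (the real logic lives in stackstac/to_dask.py), not an actual API:

```python
# Sketch only: `read_window` is a hypothetical placeholder for reading one
# small window from one COG; it is not part of stackstac's public API.
from concurrent.futures import ThreadPoolExecutor
from typing import List

import numpy as np


def read_window(url: str) -> np.ndarray:
    """Hypothetical helper: read one small window from one GeoTIFF."""
    raise NotImplementedError


def read_chunk_serial(urls: List[str]) -> np.ndarray:
    # Today's behavior inside a single dask task: ~50 reads, one after another.
    return np.stack([read_window(u) for u in urls])


def read_chunk_threaded(urls: List[str], max_workers: int = 50) -> np.ndarray:
    # The proposal: the reads are I/O-bound, so a thread pool inside the task
    # keeps only one extra window in flight per thread (small memory cost)
    # while cutting wall time by roughly the thread count.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return np.stack(list(pool.map(read_window, urls)))
```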
I had a brief peek at the profile and noticed something that stands out. Zooming in a bit, most/all of the time is spent importing a module. This might be a red herring and just an artifact of profiling... @gjoseph92 might have some thoughts here as well.
I'm running this in a container on ECS Fargate, and these files are installed when the container image is built.
I still think this seems much, much slower than it should be (by about an order of magnitude). I made a simple test doing something like what you're doing (10 points, 500 items per point, 3 bands) and ran it on the Planetary Computer Hub JupyterLab. The stacking step took me 2 seconds, compared to your 50.

This, plus the enormous variation in times in different environments, makes me think the problem isn't related to stackstac, but to something about the environment where you're running the code. Is it possible to pull out your stacking code to a much more isolated environment (say, a notebook running locally), outside of prefect, docker, dask computing stuff, etc., and see how long it takes there?

One thing I notice in your profile is that you seem to have two different event loops running, which is always concerning.
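For concreteness, an isolated harness along those lines might look something like the sketch below, modeled on the example code earlier in this thread. The point location, buffer, collection, and asset names are placeholders—substitute whatever your pipeline actually uses:

```python
# Isolated timing of just the stacking step: no prefect, no docker, no cluster.
import time

import planetary_computer as pc
import pystac_client
import shapely.geometry
import stackstac

stac_client = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/"
)
point = shapely.geometry.Point(-105.0, 40.0)  # placeholder point
items = pc.sign(
    stac_client.search(collections=["sentinel-2-l2a"], intersects=point)
)

start = time.perf_counter()
stack = stackstac.stack(
    items,
    assets=["B01", "B02", "B03"],
    bounds_latlon=point.buffer(0.0001).envelope.bounds,
    epsg=3857,
)
print(f"stacked {len(items)} items in {time.perf_counter() - start:.1f}s")
```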
I'm vaguely curious whether the slowness could have something to do with the GIL, since you generally have a lot of threads—it seems there's a lot of concurrency going on between Prefect and Dask. If you have Python things running in other threads that hold the GIL, it could significantly slow down the stacking.

I also notice that most of the time is spent in a handful of spots. Another reason the slowdowns at these particular points are interesting is that all of them release the GIL. If you do have GIL contention, then re-acquiring the GIL could make them appear to be slow (the reprojection itself would be very fast, but waiting for the GIL to become available again might take a while). See an explanation of this in dask/distributed#6325. I'm not saying this is the case, just a theoretical possibility.
Thanks for the speed recommendations. I'll try timing the stacking code in a simpler environment. In the meantime, I tried the parallelized code (simplified) and got the same import error I was mentioning. The code:

```python
def timeseries_for_point(lat: float, lon: float, bands) -> xr.DataArray:
    items = stac_client.search(intersects=dict(type="Point", coordinates=(lon, lat)),
                               collections=["sentinel-s2-l2a-cogs"]).item_collection()
    if len(items) == 0:
        return None
    stack = stackstac.stack(items, assets=bands, bounds=meter_bounds(lat, lon, 50, 50),
                            epsg=4326, chunksize=(50, 1, 50, 50)).rename({'x': 'lon', 'y': 'lat'})
    return stack

...

bag = db.from_sequence(np.concatenate([lats, lons], axis=-1), partition_size=1)
tsb = bag.starmap(timeseries_for_point, bands=bands).compute()
```

Once compute is called, the tasks get sent to the cluster, but they are only queued and never start. Looking at the scheduler logs, I get this error repeated over and over.
'connectors' is not a module at all; it's a package within my repository—the package that contains this code. No other tasks produce this error. I figured it had something to do with building a task graph inside a dask task.
Much simpler than that: your function is defined inside your `connectors` package, and that package isn't installed on the cluster, so the workers can't import it when they deserialize the task.

If you use package sync, the local package will be installed automatically on the cluster if you've installed it into your local environment (e.g. `pip install -e .`). If you're running your own docker image on the cluster, the package needs to be installed in that image.
I see - I presumed that locally defined functions would be serialized by dask and sent to the scheduler. Is it accurate to say that dask only looks at the name of the function called in the task, and always assumes that name can be resolved with local code on the worker? It just seems odd, because it implies that any user-defined function I want to parallelize needs to be installed as a package on the cluster. Do I have that wrong?
This isn't really dask; this is just how pickling functions in Python works. More or less: if the function is defined in the same file you're calling it from (your script or notebook, i.e. `__main__`), its code is serialized by value and shipped to the workers. If the function comes from an `import`ed module, only a reference to it (module and name) is serialized, and the workers have to be able to import that module to reconstruct it.
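A tiny illustration of that distinction, using cloudpickle directly (which is what dask uses for functions under the hood); run it as a script so the first function lives in `__main__`:

```python
import json  # stands in for any importable module

import cloudpickle


def defined_in_this_script(x):
    return x + 1


# Defined in __main__ (the script/notebook you're running): cloudpickle
# serializes the function's code itself, so a worker can unpickle and run it
# without importing anything from your machine.
by_value = cloudpickle.dumps(defined_in_this_script)

# Imported from a module: cloudpickle serializes only a reference
# ("json.dumps"), so whoever unpickles it must be able to import `json` --
# or, in this thread's case, the `connectors` package.
by_reference = cloudpickle.dumps(json.dumps)

print(len(by_value), len(by_reference))  # the by-value payload is much larger
```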
In my case, the function is defined in the same file where I'm calling compute from.
The only thing I can think of is that that file is itself being imported (it's part of the `connectors` package) rather than run directly as the main script.
Sorry, that was a very broad categorization that brushed over a lot of details about how Python actually works. There are cases, like this one, where you're calling compute from a file that is itself imported as a module rather than run directly, and then the function is still pickled by reference, not by value.
I would instead recommend just installing the package on the cluster. (If you're using package sync, just install the package into your local environment—e.g. `pip install -e .`—and it will be picked up automatically.)
The package suggestion makes sense; however, this code is currently under active development (not using package sync), so I'm avoiding the deployment burden and de-packaging it for now. I also wanted to see if it would work.

Interestingly, that did not solve the problem. I removed prefect entirely and ran a minimal version of the flow. When I copy all the relevant code into the one file I'm running, it works. It also runs without error if I import it but run it locally rather than on the cluster. I tried simplifying even more and running a dummy function defined in a separate file, with the same result. This seems to suggest that moving any of the code out of one file produces an error, which seems odd—what am I missing?

(In view of all this, I'll try installing the package. However, I don't want to always have to create and pip-install a package just to move a function for parallelizing into another file.)
As I said, if the function comes from an imported module, only a reference to it is serialized, and the workers have to be able to import that module.

So what you're seeing is exactly the behavior to expect: if the function is defined in a file other than the one you're running, that file is a module the workers would need to import. Running locally works because the module is importable in your local environment—it's right there on disk—which isn't the case on the cluster.
Agreed that this is annoying. However, this is unfortunately how serializing functions in Python works, and it's one part of why package sync is so helpful. If you don't want to use package sync, then I'd suggest reading this documentation, especially the sections on uploading local files.
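As a sketch of the upload-local-files route, using distributed's built-in mechanisms (the file and package names below are placeholders, not from your repo):

```python
# Ship local code to the workers at runtime instead of packaging it.
import coiled
from distributed.diagnostics.plugin import UploadDirectory

cluster = coiled.Cluster(n_workers=2)
client = cluster.get_client()

# Single-file module: sent to every worker and made importable.
client.upload_file("my_helpers.py")

# Whole local package (e.g. the `connectors` directory): extracted on each
# worker's nanny and added to its path.
client.register_worker_plugin(
    UploadDirectory("connectors", update_path=True), nanny=True
)
```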
OK, thanks for the clarification.
Thanks to your suggestions, I can confirm that parallelizing the stacking step on the cluster now works. I will get back to you on the isolated-environment profile of the stacking code. As for the considerations I brought up regarding threading the reader task—is there a more appropriate place to discuss those?
Great!
You could consider using dask-pyspy here, now that the stacking is running in dask tasks on the cluster.
An issue on stackstac would be the best place for that.
Great, I've opened an issue here: gjoseph92/stackstac#199
That looks really good - I'll try it and report back.
Description
Use the cluster to build/store large/many task graph objects (xarray objects in my case), to allow the client to be small.
Is your feature request related to a problem? Please describe.
We're now using stackstac, made by Gabe at Coiled, to combine GeoTIFF images from STAC. It's very useful. Both for memory and speed, we want to parallelize the process of building the xarray objects (building the dask graphs) and storing the objects (and graphs) on the Coiled cluster. This never works: we always get a dask error about modules not being found (this doesn't happen when parallelizing any other functions).
More generally, we've never been able to get dask to parallelize the process of building task graphs—it always seems to require running on the client. This is a problem when you have lots of graphs, or they are very large, or the input data required to build them is very large, and you want to keep the client small. We can't always get around this problem by making the graph small using chunking (e.g. with lots of stackstac xarray objects). Is it actually possible to use dask on a Coiled cluster to assist in building task graphs for xarray objects, and if so, how?