-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzarr_python_benchmark_read.py
executable file
·80 lines (66 loc) · 2.63 KB
/
zarr_python_benchmark_read.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
import numpy as np
import timeit
import asyncio
import click
from functools import wraps
import sys
import zarr
from zarr.storage import LocalStore, RemoteStore
from zarr.core.indexing import BlockIndexer
from zarr.core.buffer import default_buffer_prototype
# Set zarr's "async.concurrency" option to 10 (presumably the cap on concurrent
# async operations inside zarr -- verify against zarr's config docs). Per the
# original note, leaving it at None used too much memory.
zarr.config.set({"async.concurrency": 10})
def coro(f):
    """Turn an ``async def`` function into an ordinary synchronous callable.

    The returned wrapper calls ``f`` with the same arguments and drives the
    resulting coroutine to completion with ``asyncio.run``, returning its
    result. ``wraps`` copies ``f``'s name, docstring and ``__dict__`` onto
    the wrapper, so attributes attached to ``f`` by earlier decorators are
    preserved.
    """
    def run_sync(*args, **kwargs):
        pending = f(*args, **kwargs)
        return asyncio.run(pending)

    return wraps(f)(run_sync)
# Decorator order matters: click's argument/option decorators record their
# declarations on the coroutine function object, and `coro` (via
# functools.wraps, which copies __dict__) carries them over to the sync
# wrapper that click.command() turns into the CLI entry point.
@click.command()
@coro
@click.argument('path', type=str)
@click.option(
    '--concurrent_chunks',
    # Must be >= 1: a value of 0 would build asyncio.Semaphore(0), which never
    # grants a permit, so the benchmark would hang forever; negative values
    # make asyncio.Semaphore raise ValueError in the middle of the run.
    type=click.IntRange(min=1),
    default=None,
    help='Number of concurrent async chunk reads. Ignored if --read_all is set',
)
@click.option('--read_all', is_flag=True, show_default=True, default=False, help='Read the entire array in one operation.')
async def main(path, concurrent_chunks, read_all):
    """Time reading the zarr array at PATH.

    PATH is a local directory or an http(s) URL. The array's shape, chunk
    shape and number of chunks per dimension are printed, then the selected
    read strategy is timed and the elapsed wall-clock time is printed in
    milliseconds.
    """
    if path.startswith("http"):
        store = RemoteStore(url=path)  # broken with zarr-python 3.0.0a0
    else:
        store = LocalStore(path)
    dataset = zarr.open(store=store, mode='r')
    domain_shape = dataset.shape
    chunk_shape = dataset.chunks
    print("Domain shape", domain_shape)
    print("Chunk shape", chunk_shape)
    # Ceiling division per dimension so a partial edge chunk still counts.
    # Distinct loop names avoid shadowing `chunk_shape` inside the comprehension.
    num_chunks = [
        (extent + size - 1) // size
        for extent, size in zip(domain_shape, chunk_shape)
    ]
    print("Number of chunks", num_chunks)

    async def chunk_read(chunk_index):
        # Read one chunk through zarr's private (underscore-prefixed) async
        # internals; this may break across zarr releases.
        indexer = BlockIndexer(chunk_index, dataset.shape, dataset.metadata.chunk_grid)
        return await dataset._async_array._get_selection(
            indexer=indexer, prototype=default_buffer_prototype()
        )

    start_time = timeit.default_timer()
    if read_all:
        # One whole-array read through the synchronous API. Only the shape is
        # kept; it is printed after the timer stops so console I/O is not timed.
        result_shape = dataset[:].shape
    elif concurrent_chunks is None:
        # No limit: every chunk read is scheduled as a task at once.
        async with asyncio.TaskGroup() as tg:
            for chunk_index in np.ndindex(*num_chunks):
                tg.create_task(chunk_read(chunk_index))
    elif concurrent_chunks == 1:
        # Sequential reads via the synchronous block API.
        # NOTE(review): this blocks the running event loop and uses a different
        # code path from the async branches, so its timing is not directly
        # comparable to theirs.
        for chunk_index in np.ndindex(*num_chunks):
            dataset.get_block_selection(chunk_index)
    else:
        # All chunk reads are scheduled, but the semaphore lets at most
        # `concurrent_chunks` of them run at a time.
        semaphore = asyncio.Semaphore(concurrent_chunks)

        async def chunk_read_concurrent_limit(chunk_index):
            async with semaphore:
                return await chunk_read(chunk_index)

        async with asyncio.TaskGroup() as tg:
            for chunk_index in np.ndindex(*num_chunks):
                tg.create_task(chunk_read_concurrent_limit(chunk_index))
    elapsed = timeit.default_timer() - start_time
    if read_all:
        print(result_shape)
    elapsed_ms = elapsed * 1000.0
    print(f"Decoded in {elapsed_ms:.2f}ms")
if __name__ == "__main__":
    # `main` is a click Command whose callback already runs its own event loop
    # (see `coro`), so it is called directly. Wrapping it in asyncio.run() is
    # wrong: in click's standalone mode main() ends by raising SystemExit, so
    # asyncio.run is never reached, and if main() ever returned, asyncio.run
    # would raise ValueError on its non-coroutine argument.
    main()