Skip to content

Commit

Permalink
Allow to provide a callable pm_pbar. Closes #30
Browse files Browse the repository at this point in the history
  • Loading branch information
zeehio committed Sep 9, 2023
1 parent f8bf88e commit 04f13d1
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions parmap/parmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,14 @@ def _create_pool(kwargs):
return parallel, pool, close_pool


def _do_pbar(async_result, num_tasks, chunksize, refresh_time=2, tqdm_options={}):
def _do_pbar(
async_result, num_tasks, chunksize, refresh_time=2, pbar_wrapper=tqdm.tqdm
):
remaining = num_tasks
# tqdm provides a progress bar.
# the pbar needs to be updated with the increment on each
# iteration.
with tqdm.tqdm(total=num_tasks, **tqdm_options) as pbar:
with pbar_wrapper(total=num_tasks) as pbar:
while True:
if async_result.ready():
pbar.update(remaining)
Expand All @@ -153,11 +155,26 @@ def _get_default_chunksize(chunksize, pool, num_tasks):
return chunksize


def _serial_map_or_starmap(function, iterable, args, kwargs, progress,
map_or_starmap):
if progress or isinstance(progress, dict):
tqdm_options = progress if isinstance(progress, dict) else {}
iterable = tqdm.tqdm(iterable, **tqdm_options)
def _prepare_pbar_wrapper(progress):
has_pbar = False
wrapper = None
if progress is True and HAVE_TQDM:
has_pbar = True
wrapper = tqdm.tqdm
elif isinstance(progress, dict) and HAVE_TQDM:
has_pbar = True
wrapper = partial(tqdm.tqdm, **progress)
elif isinstance(progress, callable):
has_pbar = True
wrapper = progress
return (has_pbar, wrapper)


def _serial_map_or_starmap(
function, iterable, args, kwargs, pbar_wrapper, map_or_starmap
):
if pbar_wrapper is not None:
iterable = pbar_wrapper(iterable)
if map_or_starmap == "map":
output = [function(*([item] + list(args)), **kwargs)
for item in iterable]
Expand Down Expand Up @@ -219,15 +236,16 @@ def _map_or_starmap(function, iterable, args, kwargs, map_or_starmap):
kwargs = _deprecated_kwargs(kwargs, arg_newarg)
chunksize = kwargs.pop("pm_chunksize", None)
progress = kwargs.pop("pm_pbar", False)
progress = progress if HAVE_TQDM else False
(has_pbar, pbar_wrapper) = _prepare_pbar_wrapper(progress)
parallel, pool, close_pool = _create_pool(kwargs)
# Handle case: Execute sequentially:
if not parallel:
return _serial_map_or_starmap(function, iterable, args, kwargs,
progress, map_or_starmap)
return _serial_map_or_starmap(
function, iterable, args, kwargs, pbar_wrapper, map_or_starmap
)
func_star = _get_helper_func(map_or_starmap)
# Handle case: Without showing progress bar
if not progress:
if not has_pbar:
try:
result = pool.map_async(func_star,
zip(repeat(function),
Expand Down Expand Up @@ -267,8 +285,7 @@ def _map_or_starmap(function, iterable, args, kwargs, map_or_starmap):
pool.close()
# Progress bar:
try:
tqdm_options = progress if isinstance(progress, dict) else {}
_do_pbar(result, num_tasks, chunksize, tqdm_options=tqdm_options)
_do_pbar(result, num_tasks, chunksize, pbar_wrapper=pbar_wrapper)
finally:
output = result.get()
if close_pool:
Expand Down

0 comments on commit 04f13d1

Please sign in to comment.