From 04f13d1c7aaa97eb6078b2fd6d02a38a89573b30 Mon Sep 17 00:00:00 2001 From: Sergio Oller Date: Sat, 9 Sep 2023 17:15:16 +0200 Subject: [PATCH] Allow to provide a callable pm_pbar. Closes #30 --- parmap/parmap.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/parmap/parmap.py b/parmap/parmap.py index a2cc1c3..f72148e 100644 --- a/parmap/parmap.py +++ b/parmap/parmap.py @@ -122,12 +122,14 @@ def _create_pool(kwargs): return parallel, pool, close_pool -def _do_pbar(async_result, num_tasks, chunksize, refresh_time=2, tqdm_options={}): +def _do_pbar( + async_result, num_tasks, chunksize, refresh_time=2, pbar_wrapper=tqdm.tqdm +): remaining = num_tasks # tqdm provides a progress bar. # the pbar needs to be updated with the increment on each # iteration. - with tqdm.tqdm(total=num_tasks, **tqdm_options) as pbar: + with pbar_wrapper(total=num_tasks) as pbar: while True: if async_result.ready(): pbar.update(remaining) @@ -153,11 +155,26 @@ def _get_default_chunksize(chunksize, pool, num_tasks): return chunksize -def _serial_map_or_starmap(function, iterable, args, kwargs, progress, - map_or_starmap): - if progress or isinstance(progress, dict): - tqdm_options = progress if isinstance(progress, dict) else {} - iterable = tqdm.tqdm(iterable, **tqdm_options) +def _prepare_pbar_wrapper(progress): + has_pbar = False + wrapper = None + if progress is True and HAVE_TQDM: + has_pbar = True + wrapper = tqdm.tqdm + elif isinstance(progress, dict) and HAVE_TQDM: + has_pbar = True + wrapper = partial(tqdm.tqdm, **progress) + elif isinstance(progress, callable): + has_pbar = True + wrapper = progress + return (has_pbar, wrapper) + + +def _serial_map_or_starmap( + function, iterable, args, kwargs, pbar_wrapper, map_or_starmap +): + if pbar_wrapper is not None: + iterable = pbar_wrapper(iterable) if map_or_starmap == "map": output = [function(*([item] + list(args)), **kwargs) for item in iterable] @@ -219,15 +236,16 @@ def _map_or_starmap(function, iterable, args, kwargs, map_or_starmap): kwargs = _deprecated_kwargs(kwargs, arg_newarg) chunksize = kwargs.pop("pm_chunksize", None) progress = kwargs.pop("pm_pbar", False) - progress = progress if HAVE_TQDM else False + (has_pbar, pbar_wrapper) = _prepare_pbar_wrapper(progress) parallel, pool, close_pool = _create_pool(kwargs) # Handle case: Execute sequentially: if not parallel: - return _serial_map_or_starmap(function, iterable, args, kwargs, - progress, map_or_starmap) + return _serial_map_or_starmap( + function, iterable, args, kwargs, pbar_wrapper, map_or_starmap + ) func_star = _get_helper_func(map_or_starmap) # Handle case: Without showing progress bar - if not progress: + if not has_pbar: try: result = pool.map_async(func_star, zip(repeat(function), @@ -267,8 +285,7 @@ def _map_or_starmap(function, iterable, args, kwargs, map_or_starmap): pool.close() # Progress bar: try: - tqdm_options = progress if isinstance(progress, dict) else {} - _do_pbar(result, num_tasks, chunksize, tqdm_options=tqdm_options) + _do_pbar(result, num_tasks, chunksize, pbar_wrapper=pbar_wrapper) finally: output = result.get() if close_pool: