diff --git a/src/pytimetk/core/rolling.py b/src/pytimetk/core/rolling.py index c31f7c99..550cfd6f 100644 --- a/src/pytimetk/core/rolling.py +++ b/src/pytimetk/core/rolling.py @@ -4,7 +4,10 @@ from typing import Union, Optional, Callable, Tuple, List +from concurrent.futures import ThreadPoolExecutor, as_completed + from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column, check_value_column +from pytimetk.utils.parallel_helpers import conditional_tqdm @pf.register_dataframe_method def augment_rolling( @@ -15,6 +18,8 @@ def augment_rolling( window: Union[int, tuple, list] = 2, min_periods: Optional[int] = None, center: bool = False, + threads: int = 1, + show_progress: bool = True, **kwargs, ) -> pd.DataFrame: '''Apply one or more Series-based rolling functions and window sizes to one or more columns of a DataFrame. @@ -50,6 +55,10 @@ def augment_rolling( Minimum observations in the window to have a value. Defaults to the window size. If set, a value will be produced even if fewer observations are present than the window size. center : bool, optional, default False If `True`, the rolling window will be centered on the current value. For even-sized windows, the window will be left-biased. Otherwise, it uses a trailing window. + threads : int, optional, default 1 + Number of threads to use for parallel processing. If `threads` is set to 1, parallel processing will be disabled. Set to -1 to use all available CPU cores. + show_progress : bool, optional, default True + If `True`, a progress bar will be displayed during parallel processing. Returns ------- @@ -82,7 +91,8 @@ def augment_rolling( window_func = [ 'mean', # Built-in mean function ('std', lambda x: x.std()) # Lambda function to compute standard deviation - ] + ], + threads = 2 ) ) display(rolled_df) @@ -141,39 +151,14 @@ def augment_rolling( if isinstance(data, pd.core.groupby.generic.DataFrameGroupBy): group_names = data.grouper.names grouped = data_copy.sort_values(by=[*group_names, date_column]).groupby(group_names) - else: - group_names = None - grouped = [([], data_copy.sort_values(by=[date_column]))] + with ThreadPoolExecutor(max_workers=threads) as executor: + futures = [executor.submit(_process_single_roll, group, value_column, window_func, window, min_periods, center, **kwargs) for _, group in grouped] + # Collect results from all threads + result_dfs = [future.result() for future in conditional_tqdm(as_completed(futures), total = len(futures), desc = "Processing rolling calculations", display=show_progress)] + else: + result_dfs = [_process_single_roll(data_copy, value_column, window_func, window, min_periods, center, **kwargs)] - # Apply Series-based rolling window functions - result_dfs = [] - for _, group_df in grouped: - for value_col in value_column: - for window_size in window: - min_periods = window_size if min_periods is None else min_periods - for func in window_func: - if isinstance(func, tuple): - func_name, func = func - new_column_name = f"{value_col}_rolling_{func_name}_win_{window_size}" - group_df[new_column_name] = group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs).apply(func, raw=True) - - elif isinstance(func, str): - new_column_name = f"{value_col}_rolling_{func}_win_{window_size}" - # Get the rolling function (like mean, sum, etc.) specified by `func` for the given column and window settings - rolling_function = getattr(group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs), func, None) - # Apply rolling function to data and store in new column - if rolling_function: - group_df[new_column_name] = rolling_function() - else: - raise ValueError(f"Invalid function name: {func}") - else: - raise TypeError(f"Invalid function type: {type(func)}") - - result_dfs.append(group_df) - - # Combine processed dataframes and sort by index result_df = pd.concat(result_dfs).sort_index() # Sort by the original index - return result_df # Monkey patch the method to pandas groupby objects @@ -374,4 +359,31 @@ def rolling_apply(func, df, window_size, min_periods, center): return result_df # Monkey patch the method to pandas groupby objects -pd.core.groupby.generic.DataFrameGroupBy.augment_rolling_apply = augment_rolling_apply \ No newline at end of file +pd.core.groupby.generic.DataFrameGroupBy.augment_rolling_apply = augment_rolling_apply + + + +# UTILITIES +# --------- + +def _process_single_roll(group_df, value_column, window_func, window, min_periods, center, **kwargs): + result_dfs = [] + for value_col in value_column: + for window_size in window: + min_periods = window_size if min_periods is None else min_periods + for func in window_func: + if isinstance(func, tuple): + func_name, func = func + new_column_name = f"{value_col}_rolling_{func_name}_win_{window_size}" + group_df[new_column_name] = group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs).apply(func, raw=True) + elif isinstance(func, str): + new_column_name = f"{value_col}_rolling_{func}_win_{window_size}" + rolling_function = getattr(group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs), func, None) + if rolling_function: + group_df[new_column_name] = rolling_function() + else: + raise ValueError(f"Invalid function name: {func}") + else: + raise TypeError(f"Invalid function type: {type(func)}") + result_dfs.append(group_df) + return pd.concat(result_dfs) \ No newline at end of file