Skip to content

Commit

Permalink
#104 - Integrate parallel processing
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Oct 11, 2023
1 parent 3a1926d commit 97406df
Showing 1 changed file with 45 additions and 33 deletions.
78 changes: 45 additions & 33 deletions src/pytimetk/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from typing import Union, Optional, Callable, Tuple, List

from concurrent.futures import ThreadPoolExecutor, as_completed

from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column, check_value_column
from pytimetk.utils.parallel_helpers import conditional_tqdm

@pf.register_dataframe_method
def augment_rolling(
Expand All @@ -15,6 +18,8 @@ def augment_rolling(
window: Union[int, tuple, list] = 2,
min_periods: Optional[int] = None,
center: bool = False,
threads: int = 1,
show_progress: bool = True,
**kwargs,
) -> pd.DataFrame:
'''Apply one or more Series-based rolling functions and window sizes to one or more columns of a DataFrame.
Expand Down Expand Up @@ -50,6 +55,10 @@ def augment_rolling(
Minimum observations in the window to have a value. Defaults to the window size. If set, a value will be produced even if fewer observations are present than the window size.
center : bool, optional, default False
If `True`, the rolling window will be centered on the current value. For even-sized windows, the window will be left-biased. Otherwise, it uses a trailing window.
threads : int, optional, default 1
Number of threads to use for parallel processing. If `threads` is set to 1, parallel processing will be disabled. Set to -1 to use all available CPU cores.
show_progress : bool, optional, default True
If `True`, a progress bar will be displayed during parallel processing.
Returns
-------
Expand Down Expand Up @@ -82,7 +91,8 @@ def augment_rolling(
window_func = [
'mean', # Built-in mean function
('std', lambda x: x.std()) # Lambda function to compute standard deviation
]
],
threads = 2
)
)
display(rolled_df)
Expand Down Expand Up @@ -141,39 +151,14 @@ def augment_rolling(
if isinstance(data, pd.core.groupby.generic.DataFrameGroupBy):
group_names = data.grouper.names
grouped = data_copy.sort_values(by=[*group_names, date_column]).groupby(group_names)
else:
group_names = None
grouped = [([], data_copy.sort_values(by=[date_column]))]
with ThreadPoolExecutor(max_workers=threads) as executor:
futures = [executor.submit(_process_single_roll, group, value_column, window_func, window, min_periods, center, **kwargs) for _, group in grouped]
# Collect results from all threads
result_dfs = [future.result() for future in conditional_tqdm(as_completed(futures), total = len(futures), desc = "Processing rolling calculations", display=show_progress)]
else:
result_dfs = [_process_single_roll(data_copy, value_column, window_func, window, min_periods, center, **kwargs)]

# Apply Series-based rolling window functions
result_dfs = []
for _, group_df in grouped:
for value_col in value_column:
for window_size in window:
min_periods = window_size if min_periods is None else min_periods
for func in window_func:
if isinstance(func, tuple):
func_name, func = func
new_column_name = f"{value_col}_rolling_{func_name}_win_{window_size}"
group_df[new_column_name] = group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs).apply(func, raw=True)

elif isinstance(func, str):
new_column_name = f"{value_col}_rolling_{func}_win_{window_size}"
# Get the rolling function (like mean, sum, etc.) specified by `func` for the given column and window settings
rolling_function = getattr(group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs), func, None)
# Apply rolling function to data and store in new column
if rolling_function:
group_df[new_column_name] = rolling_function()
else:
raise ValueError(f"Invalid function name: {func}")
else:
raise TypeError(f"Invalid function type: {type(func)}")

result_dfs.append(group_df)

# Combine processed dataframes and sort by index
result_df = pd.concat(result_dfs).sort_index() # Sort by the original index

return result_df

# Monkey patch the method to pandas groupby objects
Expand Down Expand Up @@ -374,4 +359,31 @@ def rolling_apply(func, df, window_size, min_periods, center):
return result_df

# Monkey patch the method to pandas groupby objects
pd.core.groupby.generic.DataFrameGroupBy.augment_rolling_apply = augment_rolling_apply
pd.core.groupby.generic.DataFrameGroupBy.augment_rolling_apply = augment_rolling_apply



# UTILITIES
# ---------

def _process_single_roll(group_df, value_column, window_func, window, min_periods, center, **kwargs):
result_dfs = []
for value_col in value_column:
for window_size in window:
min_periods = window_size if min_periods is None else min_periods
for func in window_func:
if isinstance(func, tuple):
func_name, func = func
new_column_name = f"{value_col}_rolling_{func_name}_win_{window_size}"
group_df[new_column_name] = group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs).apply(func, raw=True)
elif isinstance(func, str):
new_column_name = f"{value_col}_rolling_{func}_win_{window_size}"
rolling_function = getattr(group_df[value_col].rolling(window=window_size, min_periods=min_periods, center=center, **kwargs), func, None)
if rolling_function:
group_df[new_column_name] = rolling_function()
else:
raise ValueError(f"Invalid function name: {func}")
else:
raise TypeError(f"Invalid function type: {type(func)}")
result_dfs.append(group_df)
return pd.concat(result_dfs)

0 comments on commit 97406df

Please sign in to comment.