Adjusted join in flat reps to account for different timestamps with t… #107

Open · wants to merge 8 commits into base: main
10 changes: 7 additions & 3 deletions EventStream/baseline/FT_task_baseline.py
@@ -46,6 +46,7 @@ def load_flat_rep(
do_update_if_missing: bool = True,
task_df_name: str | None = None,
do_cache_filtered_task: bool = True,
overwrite_cache_filtered_task: bool = False,
Ensure proper documentation for the new parameter.

The new parameter overwrite_cache_filtered_task should be included in the function's docstring to maintain comprehensive documentation.

+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, the cached file will be loaded if it exists.

subjects_included: dict[str, set[int]] | None = None,
) -> dict[str, pl.LazyFrame]:
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -67,14 +68,16 @@ def load_flat_rep(
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
try to update the stored flat representations to reflect these. If `False`, if information is
missing, it will raise a `FileNotFoundError` instead.
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
task_df_name: If specified, the flat representations loaded will be joined against the task
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
specified path for this task, then the data will be loaded from there, rather than from the base
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
the cached file will be loaded if it exists.
subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.

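The interaction between `do_cache_filtered_task` and the new `overwrite_cache_filtered_task` flag can be sketched in plain Python. This is a minimal illustration of the decision rule, not the project's actual code; the helper name is invented:

```python
def should_use_cache(cache_exists: bool, overwrite_cache_filtered_task: bool) -> bool:
    """Decide whether to load the cached filtered task file.

    The cache is reused only when the file is present and the caller has
    not asked for the flat representations to be regenerated.
    """
    return cache_exists and not overwrite_cache_filtered_task

print(should_use_cache(True, False))   # cached file present, no overwrite -> True (reuse)
print(should_use_cache(True, True))    # overwrite requested -> False (regenerate)
print(should_use_cache(False, False))  # no cache yet -> False (build it)
```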
@@ -170,7 +173,7 @@ def load_flat_rep(
if task_df_name is not None:
fn = fp.parts[-1]
cached_fp = task_window_dir / fn
if cached_fp.is_file():
if cached_fp.is_file() and not overwrite_cache_filtered_task:
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
@@ -182,7 +185,8 @@ def load_flat_rep(
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))

df = df.join(filter_join_df, on=join_keys, how="inner")
df = filter_join_df.join_asof(df, by='subject_id', on='timestamp',
strategy='forward' if '-' in window_size else 'backward')
Refactor the join_asof method to handle potential discrepancies in data alignment.

- df = filter_join_df.join_asof(df, by='subject_id', on='timestamp', 
-                               strategy='forward' if '-' in window_size else 'backward')
+ strategy = 'forward' if '-' in window_size else 'backward'
+ try:
+     df = filter_join_df.join_asof(df, by='subject_id', on='timestamp', strategy=strategy)
+ except Exception as e:
+     # Handle or log the exception
+     raise e

This refactoring adds error handling around the join_asof method, which can be crucial when dealing with large datasets and complex joins.



if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)