Adjusted join in flat reps to account for different timestamps with t… #107
base: main
@@ -46,6 +46,7 @@ def load_flat_rep(
     do_update_if_missing: bool = True,
     task_df_name: str | None = None,
     do_cache_filtered_task: bool = True,
+    overwrite_cache_filtered_task: bool = False,
     subjects_included: dict[str, set[int]] | None = None,
 ) -> dict[str, pl.LazyFrame]:
     """Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -67,14 +68,16 @@ def load_flat_rep(
         do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
             try to update the stored flat representations to reflect these. If `False`, if information is
             missing, it will raise a `FileNotFoundError` instead.
-        task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
+        task_df_name: If specified, the flat representations loaded will be joined against the task
             dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
             to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
             memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
             specified path for this task, then the data will be loaded from there, rather than from the base
             dataset.
         do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
             relevant rows for the task, be cached to disk for faster re-use.
+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
+            the cached file will be loaded if it exists.
         subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
             used wholesale.
@@ -170,7 +173,7 @@ def load_flat_rep(
             if task_df_name is not None:
                 fn = fp.parts[-1]
                 cached_fp = task_window_dir / fn
-                if cached_fp.is_file():
+                if cached_fp.is_file() and not overwrite_cache_filtered_task:
                     df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
                     if subjects_included.get(sp, None) is not None:
                         subjects = list(set(subjects).intersection(subjects_included[sp]))
@@ -182,7 +185,8 @@ def load_flat_rep(
             if task_df_name is not None:
                 filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
 
-                df = df.join(filter_join_df, on=join_keys, how="inner")
+                df = filter_join_df.join_asof(df, by='subject_id', on='timestamp',
+                                              strategy='forward' if '-' in window_size else 'backward')
Review comment:

Refactor the as-of join to name the strategy explicitly and add error handling.

Suggested change:

-                df = filter_join_df.join_asof(df, by='subject_id', on='timestamp',
-                                              strategy='forward' if '-' in window_size else 'backward')
+                strategy = 'forward' if '-' in window_size else 'backward'
+                try:
+                    df = filter_join_df.join_asof(df, by='subject_id', on='timestamp', strategy=strategy)
+                except Exception as e:
+                    # Handle or log the exception
+                    raise e

This refactoring adds error handling around the as-of join call.
             if do_cache_filtered_task:
                 cached_fp.parent.mkdir(exist_ok=True, parents=True)
Review comment:

Ensure proper documentation for the new parameter. The new parameter `overwrite_cache_filtered_task` should be included in the function's docstring to maintain comprehensive documentation.

Suggested change:

+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
+            the cached file will be loaded if it exists.