Fixes the slowdowns and bugs caused by the prior improved compute practices, but requires a nested tensor package. #90
Conversation
@mmcdermott which polars and numpy versions were used for this branch? Building the dataset works fine with this branch. When running pretrain, I initially got a runtime error stemming from here: EventStreamGPT/EventStream/data/pytorch_dataset.py, lines 253 to 255 in 29c3b9f.
The error being: When I change the numpy array dtype to object, it just hangs (it runs for an incredibly long time, and I haven't been patient enough to wait it out). I don't know whether safetensors or nested_ragged_tensors is to blame (I assume the latter). This occurs when trying to run pretrain, sampling about 90,000 subjects from a full dataset of about 600,000.
…eamML into using_ragged_tensors
@juancq, definitely don't feel obliged to wait it out when setting it to an object; that will be very slow and defeats the purpose of this change. One question, though: in your data, do you have reason to believe that your subjects will have varying numbers of observations of static measurements per row? I believe that what is going on here is that something that is treated as a dense tensor is in reality a ragged one.
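A minimal sketch of the failure mode being described here, assuming plain NumPy behavior (the exact exception text varies by NumPy version):

```python
import numpy as np

# Same number of static measurements per subject: a clean rectangular array.
rectangular = np.array([[1, 2], [3, 4]], dtype=np.int64)
print(rectangular.shape)  # (2, 2)

# Ragged case: subjects with differing numbers of static measurements.
ragged = [[1, 2], [3]]
try:
    np.array(ragged, dtype=np.int64)  # cannot be made rectangular
except (TypeError, ValueError) as e:
    print(f"Dense conversion failed: {e}")

# Forcing dtype=object "works", but each row is then just a Python list,
# which is why downstream processing becomes extremely slow.
slow = np.array(ragged, dtype=object)
print(slow.dtype, slow.shape)  # object (2,)
```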
I don't follow; what would be an example? This branch also has parts that break with certain polars versions, which makes it hard to test because I don't know which exact polars version to use.
@mmcdermott I have sorted out the polars issues and have gotten further in testing this. My pretraining now hangs here:
When I kill the script, the stack trace is something along the following lines (in ragged tensors): If this is too cryptic, let me know and I'll rephrase, or I can post an issue on the nested_ragged_tensors repo.
Thank you @juancq -- I've been travelling quite a bit and otherwise occupied for the last month and a half, but I'm trying to push a new major version of ESGPT that addresses these issues and the other memory issues. I assume your last comment still reflects the state of things with this change for you?
@coderabbitai review
Walkthrough
The recent changes in the EventStream project involve enhancing data handling, caching mechanisms, and error management across multiple files. These updates include refining file extension checks, improving exception handling for data conversion, restructuring caching mechanisms for efficiency, and aligning test cases with the updated data structures and logic.
Changes
Assessment against linked issues
Review Status
Actionable comments generated: 2
Configuration used: CodeRabbit UI
Files ignored due to path filters (1)
- pyproject.toml is excluded by !**/*.toml
Files selected for processing (3)
- EventStream/data/config.py (1 hunks)
- EventStream/data/dataset_polars.py (2 hunks)
- EventStream/data/pytorch_dataset.py (7 hunks)
Additional comments: 11
EventStream/data/pytorch_dataset.py (9)
- 9-21: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [1-17]. The imports are appropriate and align with the functionality implemented in this file. The use of nested_ragged_tensors is particularly noteworthy, as it directly relates to the PR's objective of handling ragged tensors more efficiently.
- 77-77: Ensure that len(self.sparse_static_tensors) correctly reflects the dataset's length. If sparse_static_tensors is expected to be a complete representation of the dataset, this is fine. Otherwise, consider a more robust way to determine the dataset's length.
- 133-223: The caching logic in _cache_subset and the subsequent saving of tensors and data stats are well-implemented. However, consider the following improvements for readability and efficiency:
  - Extracting repeated code blocks into helper functions.
  - Using more descriptive variable names for clarity.
  - Adding comments to complex logic sections for better maintainability.
- 244-289: The construction and saving of dense and ragged tensors in _cache_full_data are crucial for the dataset's performance. Ensure that:
  - The data types and shapes of tensors are correctly handled, especially when converting lists to numpy arrays and tensors.
  - The use of JointNestedRaggedTensorDict for ragged tensors aligns with the expected data structure and performance requirements.
- 292-305: The method fetch_tensors efficiently loads the cached tensors and applies necessary configurations. However, consider the following:
  - Ensure that the deletion of tensors (start_time, start_idx, end_idx, subject_id) based on configuration flags does not inadvertently affect downstream processing.
  - Validate that all necessary tensors are loaded and correctly handled for different configurations.
- 353-357: The __getitem__ method's implementation appears correct. Ensure that the slicing and indexing logic correctly handles edge cases, especially for datasets with varying sequence lengths or missing data.
- 373-393: The collate method's handling of ragged tensors and conversion to dense tensors is critical. Verify that:
  - The conversion logic (torch.from_numpy(v.astype(int)).long(), torch.from_numpy(v.astype(float)).float()) correctly handles all data types present in the dataset.
  - The handling of masks (event_mask, dynamic_values_mask) correctly identifies valid and missing data points.
- 401-408: In the collate method, ensure that the conversion of tensors to the appropriate types (float, long) and the handling of nan values (torch.nan_to_num) are correctly applied to all relevant tensors. This is crucial for maintaining data integrity and preventing errors in downstream processing.
- 244-289: The __getitem__ and collate methods in ConstructorPytorchDataset are crucial for data loading and preprocessing. Ensure that:
  - The logic for handling static and dynamic data, including padding and type conversion, is correctly implemented.
  - The handling of task labels (multi_class_classification, binary_classification, regression) is robust and correctly converts labels to the appropriate tensor types.
EventStream/data/config.py (1)
- 1040-1040: The modification to use a dictionary comprehension for file retrieval seems efficient. However, it would be beneficial to add a comment specifying the expected file extension for clarity and maintainability.
EventStream/data/dataset_polars.py (1)
- 13-13: Adding defaultdict from the collections module is a good choice for handling dynamic and potentially unknown keys in dictionaries. This import is used effectively in the exception handling within the _filter_col_inclusion method.
try:
    incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
except TypeError as e:
    incl_targets_by_type = defaultdict(list)
    for t in incl_targets:
        incl_targets_by_type[str(type(t))].append(t)

    by_type_summ = []
    for tp, vals in incl_targets_by_type.items():
        by_type_summ.append(
            f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..."
        )

    by_type_summ = "\n".join(by_type_summ)

    raise ValueError(
        f"Failed to convert incl_targets to {df.schema[col]}:\n{by_type_summ}"
    ) from e
The changes made to the _filter_col_inclusion method enhance its robustness by adding exception handling for type conversion issues. This is particularly useful when dealing with dynamic data types that may not always conform to expected formats. The detailed error message constructed from the incl_targets_by_type dictionary provides clear insight into the nature of the conversion failure, which can significantly aid in debugging.
However, consider adding a comment explaining the purpose of creating a summary of conversion failures by type and values. This will help future maintainers understand the rationale behind these changes more quickly.
+ # Handle type conversion exceptions by summarizing conversion failures by type and values
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
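To make the resulting error message concrete, here is a small, self-contained sketch of the same grouping logic applied to a hypothetical mixed-type inclusion list (the values, and the idea of an Int64 target column, are made up for illustration):

```python
from collections import defaultdict

# Hypothetical inclusion targets of mixed types, e.g. supplied for an Int64 column.
incl_targets = [1, 2, "3", None, 4.5]

incl_targets_by_type = defaultdict(list)
for t in incl_targets:
    incl_targets_by_type[str(type(t))].append(t)

by_type_summ = "\n".join(
    f"{tp}: {len(vals)} values: {', '.join(str(x) for x in vals[:5])}..."
    for tp, vals in incl_targets_by_type.items()
)
print(by_type_summ)
# <class 'int'>: 2 values: 1, 2...
# <class 'str'>: 1 values: 3...
# <class 'NoneType'>: 1 values: None...
# <class 'float'>: 1 values: 4.5...
```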
EventStream/data/pytorch_dataset.py (Outdated)
data_as_lists = defaultdict(list)
for ep in tqdm(range(self.config.cache_for_epochs), total=self.config.cache_for_epochs, leave=False):
for it in tqdm(constructor_pyd, total=len(constructor_pyd)):
items.append(it)

logger.info("Collating data into dense tensors to cache.")
global_batch = constructor_pyd.collate(items, do_convert_float_nans=False)

tensors_to_cache = []
seen_keys = set()
for k, T in global_batch.items():
if k.endswith("_mask") and k != "event_mask":
continue
if T is None:
continue
if isinstance(T, torch.Tensor):
if k in seen_keys:
raise KeyError(f"Duplicate tensor save key {k}!")
tensors_to_cache.append((k, T))
seen_keys.add(k)
elif isinstance(T, dict):
for kk, TT in T.items():
if TT is None:
continue
elif not isinstance(TT, torch.Tensor):
raise TypeError(f"Unrecognized tensor type {type(TT)} @ {k}/{kk}!")

if kk in seen_keys:
raise KeyError(f"Duplicate tensor save key {kk}!")
tensors_to_cache.append((kk, TT))
seen_keys.add(kk)
else:
raise TypeError(f"Unrecognized tensor type {type(T)} @ {k}!")
for k, val in it.items():
data_as_lists[k].append(val)

logger.info("Constructing tensors to cache.")
logger.info(f"Dataset keys: {data_as_lists.keys()}")

sparse_1_keys = ["time_delta", "dynamic_indices", "dynamic_values", "dynamic_measurement_indices"]
sparse_2_keys = ["static_indices", "static_measurement_indices"]
dense_keys = [k for k in data_as_lists.keys() if k not in sparse_1_keys + sparse_2_keys]
tensor_types = {
"subject_id": np.int64,
"static_indices": np.int64,
"static_measurement_indices": np.int64,
}

for k, T in tqdm(tensors_to_cache, leave=False, desc="Caching..."):
fp = self._full_data_config.tensorized_cached_dir / self.split / f"{k}.pt"
fp.parent.mkdir(exist_ok=True, parents=True)
st = datetime.now()
logger.info(f"Caching tensor {k} of shape {T.shape} to {fp}...")
torch.save(T, fp)
logger.info(f"Done in {datetime.now() - st}")
# Dense tensors
dense_tensors = {}
for k in dense_keys:
logger.info(f"Collating {k} as a dense tensor")
dense_tensors[k] = np.array(data_as_lists[k], dtype=tensor_types.get(k, np.float32))

fp = self._full_data_config.tensorized_cached_dir / self.split / "dense.npz"
logger.info(f"Saving dense tensors to {fp}")
save_file(dense_tensors, fp)

# Ragged tensors
logger.info(f"Constructing ragged tensors across {sparse_1_keys}")
sparse_tensors_dict = {k: data_as_lists[k] for k in sparse_1_keys}
sparse_tensors_dict["dynamic_values"] = [
[[v if v is not None else float("nan") for v in measurements] for measurements in events]
for events in sparse_tensors_dict["dynamic_values"]
]
sparse_tensors = JointNestedRaggedTensorDict(sparse_tensors_dict)
fp = self._full_data_config.tensorized_cached_dir / self.split / "sparse_dynamic.npz"
logger.info(f"Saving sparse tensors to {fp}")
sparse_tensors.save(fp)

logger.info(f"Constructing ragged tensors across {sparse_2_keys}")
sparse_tensors_dict = {k: data_as_lists[k] for k in sparse_2_keys}
sparse_tensors = JointNestedRaggedTensorDict(sparse_tensors_dict)
fp = self._full_data_config.tensorized_cached_dir / self.split / "sparse_static.npz"
logger.info(f"Saving sparse tensors to {fp}")
sparse_tensors.save(fp)
The method _build_task_cached_df in ConstructorPytorchDataset is complex and critical for handling task-specific data. Consider the following improvements:
- Simplify the logic by breaking down the method into smaller, more manageable functions.
- Add detailed comments explaining the purpose and logic of each step, especially for complex operations.
- Ensure that the handling of time-dependent columns and the construction of task-specific dataframes are correctly implemented and optimized for performance.
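To make the "ragged" structure handled in the diff above concrete, here is a small library-free sketch (not the nested_ragged_tensors implementation): dynamic_values is a subjects → events → measurements nesting with variable lengths at each level, None values are swapped for NaN exactly as in the diff, and a ragged encoding then stores flat values plus per-level lengths instead of padding everything out to a rectangle.

```python
# Toy dynamic_values for two subjects: subjects -> events -> measurement values.
dynamic_values = [
    [[1.0, None], [2.5]],        # subject 0: two events with 2 and 1 measurements
    [[None], [3.0, 4.0, 5.0]],   # subject 1: two events with 1 and 3 measurements
]

# Same None -> NaN substitution as in the cached-tensor construction above.
cleaned = [
    [[v if v is not None else float("nan") for v in measurements] for measurements in events]
    for events in dynamic_values
]

# A ragged encoding keeps flat values plus the lengths at each nesting level.
events_per_subject = [len(events) for events in cleaned]            # [2, 2]
values_per_event = [len(m) for events in cleaned for m in events]   # [2, 1, 1, 3]
flat_values = [v for events in cleaned for m in events for v in m]  # 7 values total

print(events_per_subject, values_per_event, flat_values)
```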
Yes, the state of things on my end is the same as of the time I wrote the last comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Actionable comments outside the diff hunks (15)
EventStream/data/dataset_polars.py (1)
Line range hint 203-203: Replace type comparison with isinstance() for better practice and compatibility.
- if type(qq) is Path:
+ if isinstance(qq, Path):
EventStream/data/config.py (14)
Line range hint 117-117: Replace type comparison with isinstance() for better practice and readability.
- if type(self.static) is dict:
+ if isinstance(self.static, dict):
Line range hint 125-125: Replace type comparison with isinstance() for better practice and readability.
- if type(v) is dict:
+ if isinstance(v, dict):
Line range hint 407-407: Replace type comparison with isinstance() for better practice and readability.
- if type(self.data_schema) is not list and self.data_schema is not None:
+ if not isinstance(self.data_schema, list) and self.data_schema is not None:
Line range hint 409-409: Replace type comparison with isinstance() for better practice and readability.
- if type(self.start_data_schema) is not list and self.start_data_schema is not None:
+ if not isinstance(self.start_data_schema, list) and self.start_data_schema is not None:
Line range hint 411-411: Replace type comparison with isinstance() for better practice and readability.
- if type(self.end_data_schema) is not list and self.end_data_schema is not None:
+ if not isinstance(self.end_data_schema, list) and self.end_data_schema is not None:
Line range hint 625-625: Replace type comparison with isinstance() for better practice and readability.
- if type(self.min_seq_len) is not int or self.min_seq_len < 0:
+ if not isinstance(self.min_seq_len, int) or self.min_seq_len < 0:
Line range hint 631-631: Replace type comparison with isinstance() for better practice and readability.
- if type(self.max_seq_len) is not int or self.max_seq_len < self.min_seq_len:
+ if not isinstance(self.max_seq_len, int) or self.max_seq_len < self.min_seq_len:
Line range hint 649-649: Replace type comparison with isinstance() for better practice and readability.
- if type(self.save_dir) is str and self.save_dir != omegaconf.MISSING:
+ if isinstance(self.save_dir, str) and self.save_dir != omegaconf.MISSING:
Line range hint 911-911: Replace type comparison with isinstance() for better practice and readability.
- if type(val) is not dict:
+ if not isinstance(val, dict):
Line range hint 913-913: Replace type comparison with isinstance() for better practice and readability.
- if type(val) is not dict:
+ if not isinstance(val, dict):
Line range hint 919-919: Replace type comparison with isinstance() for better practice and readability.
- if type(val) is not dict:
+ if not isinstance(val, dict):
Line range hint 1922-1922: Replace type comparison with isinstance() for better practice and readability.
- if type(self.save_dir) is str:
+ if isinstance(self.save_dir, str):
Line range hint 1931-1931: Replace type comparison with isinstance() for better practice and readability.
- if type(val) is not dict:
+ if not isinstance(val, dict):
Line range hint 1958-1958: Replace type comparison with isinstance() for better practice and readability.
- if type(val) is not dict:
+ if not isinstance(val, dict):
This is blocked by #104
…resource intensive manner.
Actionable comments posted: 4
Out of diff range and nitpick comments (4)
EventStream/baseline/FT_task_baseline.py (3)
[!TIP] Codebase Verification
Several files in the codebase still use the old function signature.
- No direct or indirect tests for the load_flat_rep function were found in the codebase.
- Ensure that comprehensive tests are added to cover all branches and exception handling for load_flat_rep.
Analysis chain
Line range hint 35-241: Ensure comprehensive testing of load_flat_rep to cover all branches and exception handling.
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.
# Test: Search for test cases covering all branches in `load_flat_rep`.
rg --type python $'def test_load_flat_rep'
Length of output: 74
Script:
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.
# Test: Search for test cases covering all branches in `load_flat_rep`.
rg --type py 'def test_load_flat_rep'
Length of output: 39
Script:
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.
# Test: Search for references to `load_flat_rep` in test files.
rg 'load_flat_rep' --type py --glob '*test*.py'
Length of output: 47
Script:
#!/bin/bash
# Description: Verify comprehensive tests for `load_flat_rep`.
# Test: Search for references to `FT_task_baseline.py` and `load_flat_rep` within the codebase.
rg 'FT_task_baseline' --type py
rg 'load_flat_rep' --type py
Length of output: 305
Script:
#!/bin/bash
# Description: Display the contents of `scripts/sklearn_baseline.py` to check for testing logic.
# Display the contents of the file.
cat scripts/sklearn_baseline.py
Length of output: 798
Line range hint 241-241: Use isinstance() for type checking instead of comparing types directly.
- if type(window_sizes) is not list:
+ if not isinstance(window_sizes, list):
Line range hint 241-241: Use isinstance() for type checking to ensure flexibility and compatibility.
- if type(window_sizes) is not list:
+ if not isinstance(window_sizes, list):
EventStream/data/dataset_polars.py (1)
Line range hint 204-204: Use isinstance() instead of direct type comparison.
- if type(qq) is Path:
+ if isinstance(qq, Path):
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files ignored due to path filters (1)
- pyproject.toml is excluded by !**/*.toml
Files selected for processing (5)
- EventStream/baseline/FT_task_baseline.py (2 hunks)
- EventStream/data/dataset_base.py (2 hunks)
- EventStream/data/dataset_polars.py (3 hunks)
- EventStream/data/pytorch_dataset.py (8 hunks)
- tests/data/test_pytorch_dataset.py (7 hunks)
Files not summarized due to errors (1)
- EventStream/data/pytorch_dataset.py: Error: Message exceeds token limit
Additional Context Used
Ruff (2)
EventStream/baseline/FT_task_baseline.py (1)
241-241: Do not compare types, use isinstance()
EventStream/data/dataset_polars.py (1)
204-204: Do not compare types, use isinstance()
Additional comments not posted (16)
tests/data/test_pytorch_dataset.py (6)
13-13: Import of JointNestedRaggedTensorDict added. This import is necessary for handling nested ragged tensor data structures, which aligns with the PR's objectives to handle complex tensor operations.
56-58: Calculation of event time deltas. The logic for calculating time deltas between events is correct and efficiently uses list comprehensions. This is crucial for the time series data manipulations required in the dataset (see the sketch after this list).
Also applies to: 66-68, 77-79
85-86: Introduction of time and time_delta fields in DL_REP_DF. These fields are essential for representing time series data in a structured format, which is beneficial for downstream processing and analysis.
315-361: Setup and teardown methods in TestPytorchDataset. The setup and teardown methods are well-structured, ensuring that temporary directories and files are correctly managed. This is crucial for maintaining test isolation and cleanliness.
Line range hint 363-424: Implementation of the get_pyd method. The method is implemented correctly to configure and retrieve a PytorchDataset instance based on the provided configuration. This flexibility is important for testing different configurations.
421-424: Assertion logic in test_normalize_task. The assertion checks that the normalized values match the expected results, which is crucial for verifying the correctness of the normalization logic.
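The sketch referenced above: a toy illustration of per-subject event time deltas (not the project's test code; the forward-delta convention with a trailing placeholder is an assumption made for illustration).

```python
from datetime import datetime

# Hypothetical timestamps for one subject's events.
event_times = [
    datetime(2020, 1, 1, 8, 0),
    datetime(2020, 1, 1, 9, 30),
    datetime(2020, 1, 2, 8, 0),
]

# Delta (in minutes) from each event to the next; the last event has no
# successor, so it gets a NaN placeholder here.
time_delta = [
    (t2 - t1).total_seconds() / 60 for t1, t2 in zip(event_times[:-1], event_times[1:])
] + [float("nan")]

print(time_delta)  # [90.0, 1350.0, nan]
```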
EventStream/data/pytorch_dataset.py (4)
9-15: Import of JointNestedRaggedTensorDict and related constants. These imports are necessary for handling complex data structures involving nested and ragged tensors, aligning with the PR's objectives to enhance data handling capabilities.
Line range hint 61-160: Comprehensive restructuring of the PytorchDataset class. The restructuring introduces more robust handling of configurations, vocabulary, and data shards. This enhancement is crucial for managing complex datasets and configurations in a scalable manner.
Line range hint 172-398: Methods for reading and setting configurations and data. These methods are well-implemented to ensure robust data management and error handling, which is essential for maintaining data integrity and facilitating efficient data processing.
Line range hint 426-611: Implementation of the __getitem__ and collate methods. These methods are crucial for fetching and collating data in a format suitable for PyTorch models. The implementation handles various data types and structures effectively, ensuring compatibility with different model architectures.
EventStream/baseline/FT_task_baseline.py (1)
35-35: Update import to reflect the new class name.
EventStream/data/dataset_base.py (3)
21-21: Ensure the new import polars as pl is utilized appropriately in the code.
24-24: Ensure the new import JointNestedRaggedTensorDict from nested_ragged_tensors.ragged_numpy is utilized appropriately in the code.
1367-1367: The creation of NRT_dir is not followed by any operations within this directory in the provided code. Verify if this is intended or if additional code is missing.
Verification successful
The creation of NRT_dir is followed by operations within this directory, such as creating directories and saving files. Therefore, the code appears to be functioning as intended.
NRT_fp.parent.mkdir(exist_ok=True, parents=True)
jnrt_dict.save(NRT_fp)
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for usage of NRT_dir in the repository.
rg --type py 'NRT_dir'
Length of output: 189
Script:
#!/bin/bash
# Description: Check for usage of NRT_fp in the repository.
rg --type py 'NRT_fp'
Length of output: 821
EventStream/data/dataset_polars.py (2)
726-743: Enhance clarity by adding a comment explaining the purpose of summarizing conversion failures by type and values.
1410-1420: The implementation of time_delta calculation in build_DL_cached_representation looks correct and efficient.
@@ -1358,28 +1360,67 @@ def cache_deep_learning_representation(
"""

logger.info("Caching DL representations")
if subjects_per_output_file is None:
Consider specifying a default value for subjects_per_output_file to avoid potential issues with undefined behavior.
- if subjects_per_output_file is None:
+ if subjects_per_output_file is None:
+     subjects_per_output_file = 100  # Default value, adjust as necessary
split_cached_df = self._filter_col_inclusion(cached_df, {"subject_id": subjects})
self._write_df(split_cached_df, fp, do_overwrite=do_overwrite)
shards_fp.write_text(json.dumps(shards))
Consider using json.dump directly with a file handle instead of write_text for writing JSON data. This is more idiomatic and efficient.
- shards_fp.write_text(json.dumps(shards))
+ with open(shards_fp, 'w') as f:
+     json.dump(shards, f)
logger.info(f"Skipping {NRT_fp} as it already exists.") | ||
else: | ||
logger.info(f"Caching NRT for {shard_key} to {NRT_fp}") | ||
# TODO(mmd): This breaks the API isolation a bit, as we assume polars here. But that's fine. |
The comment about API isolation breakage should be addressed or clarified to ensure the design is robust and maintainable.
Consider revising the architecture to maintain API isolation or provide a detailed justification for the current approach.
@@ -660,7 +660,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
task_dfs = add_tasks_from(ESD.config.save_dir / "task_dfs")
task_df = task_dfs[cfg.task_df_name]

- task_type, normalized_label = ConstructorPytorchDataset.normalize_task(
+ task_type, normalized_label = PytorchDataset.normalize_task(
Consider adding more detailed logging for better traceability during model training.
+ logger.debug(f"Task type: {task_type}, Label: {normalized_label}")
@juancq I know this update is long overdue, but the recent pushes use the nested ragged tensor code in a new way that should dramatically reduce the CPU memory burden during model training. You'll need to re-build the base dataset object first to produce the right cached files (though I may write a conversion script to ease that cost, if that would be helpful), but once that is done this should, with minimal to no impact on throughput, set things up to only pull the patient data from disk as needed rather than loading it all into memory up front. Note this requires updating to the latest version of nested_ragged_tensors as well. If you try it and find it useful or find issues with it, I'd be very appreciative of and curious about your findings!
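For intuition only, here is a rough sketch of the "pull from disk as needed" pattern described above. It is not the ESGPT or nested_ragged_tensors implementation (the real cache stores ragged per-subject event data); it just shows the general load-on-demand idea using a memory-mapped NumPy file, with a made-up toy cache path.

```python
import numpy as np
import torch


class LazyCachedDataset(torch.utils.data.Dataset):
    """Sketch: keep the cached array on disk and read one row per __getitem__."""

    def __init__(self, cache_fp):
        # mmap_mode="r" means the file's contents are not loaded into memory here;
        # only the pages touched by each row access are read from disk.
        self.data = np.load(cache_fp, mmap_mode="r")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        # Copy just this subject's row into memory and hand it to torch.
        return torch.from_numpy(np.array(self.data[i]))


# Usage with a toy cache file standing in for the real cached shard format.
np.save("toy_cache.npy", np.arange(12, dtype=np.float32).reshape(4, 3))
ds = LazyCachedDataset("toy_cache.npy")
print(len(ds), ds[2])  # 4 tensor([6., 7., 8.])
```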
Actionable comments posted: 0
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files ignored due to path filters (1)
- pyproject.toml is excluded by !**/*.toml
Files selected for processing (1)
- EventStream/data/types.py (1 hunks)
Files skipped from review due to trivial changes (1)
- EventStream/data/types.py
Actionable comments posted: 0
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (1)
- EventStream/data/pytorch_dataset.py (8 hunks)
Files not summarized due to errors (1)
- EventStream/data/pytorch_dataset.py: Error: Message exceeds token limit
Additional comments not posted (12)
EventStream/data/pytorch_dataset.py (12)
9-15: Imports from nested_ragged_tensors.ragged_numpy are added to handle nested and ragged tensor data.
61-62: Introduction of the PytorchDataset class with detailed documentation.
132-156: The constructor (__init__) has been expanded with several new logging statements and method calls to set up the dataset based on the configuration. This includes reading vocabulary, shards, and patient descriptors, and applying filters based on sequence length and training subset size.
169-179: Methods read_vocabulary and read_shards are introduced to load configuration and data shard information from disk. These methods are crucial for setting up the dataset and ensuring that data is loaded efficiently.
185-215: The read_patient_descriptors method has been significantly expanded to handle loading and processing of patient descriptors. This includes reading parquet files and handling potential duplicate subject entries, which could lead to data inconsistencies.
226-242: Handling of task-specific dataframes and information has been added. This includes reading and validating task constraints and information from disk, and handling discrepancies in task information. This is critical for ensuring that the dataset is correctly set up for specific tasks.
247-287: Complex data manipulation and joining operations are performed to align task-specific data with the main dataset. This includes handling of start and end indices for tasks, which is crucial for correct data slicing during model training.
296-318: The get_task_info method is introduced to extract and normalize task information from the dataframe. This includes determining the type of task (classification, regression) and setting up vocabulary for tasks. This method is essential for configuring the dataset to handle different types of tasks correctly.
322-342: Methods filter_to_min_seq_len and filter_to_subset have been added to apply filters on the dataset based on minimum sequence length and subset size. These methods are important for ensuring that the dataset meets specific training requirements.
Line range hint 365-396: The method set_inter_event_time_stats calculates statistics on inter-event times and handles cases where these times are invalid (<=0). This is crucial for ensuring data quality and consistency.
Line range hint 424-509: The __getitem__ method and its helper _seeded_getitem have been significantly expanded to handle dynamic and static data indices, and to apply subsequence sampling strategies based on the configuration. This is critical for preparing data for model input.
Line range hint 511-609: The collate method and its helper __dynamic_only_collate have been updated to handle the collation of batch data, including handling of ragged tensor data and padding of static data fields. This is essential for preparing batches of data for model training (a padding sketch follows this list).
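The padding sketch referenced in the last item: a generic, hedged illustration of padding variable-length per-subject index lists into a dense batch plus a validity mask (the field names and shapes are made up; this is not the project's collate code).

```python
import torch


def pad_batch(static_indices_per_subject, pad_value=0):
    """Pad variable-length index lists to a dense (batch, max_len) tensor plus a mask."""
    n = len(static_indices_per_subject)
    max_len = max(len(x) for x in static_indices_per_subject)
    batch = torch.full((n, max_len), pad_value, dtype=torch.long)
    mask = torch.zeros((n, max_len), dtype=torch.bool)
    for i, idxs in enumerate(static_indices_per_subject):
        batch[i, : len(idxs)] = torch.tensor(idxs, dtype=torch.long)
        mask[i, : len(idxs)] = True
    return batch, mask


# Three subjects with 2, 1, and 3 static measurement indices respectively.
batch, mask = pad_batch([[3, 7], [5], [2, 4, 9]])
print(batch)  # tensor([[3, 7, 0], [5, 0, 0], [2, 4, 9]])
print(mask)   # True where entries are real, False where they are padding
```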
@mmcdermott thanks for all the hard work. I tested this branch on my dataset. The previous bugs are gone. I am now seeing about a 7% runtime improvement per epoch and about 30% lower memory usage.
Fantastic! Thanks so much @juancq. I'll do some final testing just to make sure there are no issues and plan to merge this branch in soon. Glad this has resolved your issues and brought other improvements besides.
This should fix #73 as well