Skip to content

Commit

Permalink
Added tuning subset size and seed to data config
Browse files Browse the repository at this point in the history
  • Loading branch information
pargaw committed Jul 11, 2024
1 parent 69b99ce commit bf453e1
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions EventStream/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,10 @@ class PytorchDatasetConfig(JSONableMixin):
training subset. If `None` or "FULL", then the full training data is used.
train_subset_seed: If the training data should be subsampled randomly, this specifies the seed for
that random subsampling.
tuning_subset_size: If the tuning data should be subsampled randomly, this specifies the size of the
tuning subset. If `None` or "FULL", then the full tuning data is used.
tuning_subset_seed: If the tuning data should be subsampled randomly, this specifies the seed for
that random subsampling.
task_df_name: If the raw dataset should be limited to a task dataframe view, this specifies the name
of the task dataframe, and indirectly the path on disk from where that task dataframe will be
read (save_dir / "task_dfs" / f"{task_df_name}.parquet").
Expand Down Expand Up @@ -873,6 +877,8 @@ class PytorchDatasetConfig(JSONableMixin):

train_subset_size: int | float | str = "FULL"
train_subset_seed: int | None = None
tuning_subset_size: int | float | str = "FULL"
tuning_subset_seed: int | None = None

task_df_name: str | None = None

Expand Down Expand Up @@ -907,6 +913,22 @@ def __post_init__(self):
pass
case _:
raise TypeError(f"train_subset_size is of unrecognized type {type(self.train_subset_size)}.")

match self.tuning_subset_size:
case int() as n if n < 0:
raise ValueError(f"If integral, tuning_subset_size must be positive! Got {n}")
case float() as frac if frac <= 0 or frac >= 1:
raise ValueError(f"If float, tuning_subset_size must be in (0, 1)! Got {frac}")
case int() | float() if (self.tuning_subset_seed is None):
seed = int(random.randint(1, int(1e6)))
print(f"WARNING! tuning_subset_size is set, but tuning_subset_seed is not. Setting to {seed}")
self.tuning_subset_seed = seed
case None | "FULL" | int() | float():
pass
case _:
raise TypeError(
f"tuning_subset_size is of unrecognized type {type(self.tuning_subset_size)}."
)

def to_dict(self) -> dict:
"""Represents this configuration object as a plain dictionary."""
Expand Down

0 comments on commit bf453e1

Please sign in to comment.