Commit a51e695 (Merged)
mmcdermott committed Apr 22, 2024
2 parents aa953ab + 9eead53
Showing 46 changed files with 3,735 additions and 4,862 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -19,10 +19,10 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v3

-      - name: Set up Python 3.10
+      - name: Set up Python 3.11
         uses: actions/setup-python@v3
         with:
-          python-version: "3.10"
+          python-version: "3.11"

       - name: Install packages
         run: |
26 changes: 15 additions & 11 deletions EventStream/baseline/FT_task_baseline.py
@@ -16,6 +16,7 @@
 import polars.selectors as cs
 import wandb
 from hydra.core.config_store import ConfigStore
+from loguru import logger
 from omegaconf import OmegaConf
 from sklearn.decomposition import NMF, PCA
 from sklearn.ensemble import RandomForestClassifier
@@ -31,11 +32,11 @@
 from sklearn.preprocessing import MinMaxScaler, StandardScaler

 from ..data.dataset_polars import Dataset
-from ..data.pytorch_dataset import PytorchDataset
+from ..data.pytorch_dataset import ConstructorPytorchDataset
 from ..tasks.profile import add_tasks_from
 from ..utils import task_wrapper

-pl.enable_string_cache(True)
+pl.enable_string_cache()


 def load_flat_rep(
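Side note on the pl.enable_string_cache() change: in newer Polars releases the boolean argument was deprecated, so the global string cache is switched on with a bare call, or scoped with a context manager. A minimal sketch of both forms, assuming a recent Polars version (data and column names are illustrative, not from this commit):

import polars as pl

# Global: enable the string cache for the whole process (newer Polars takes no argument).
pl.enable_string_cache()

# Scoped alternative: categoricals created inside the block share one cache,
# so the two frames' categorical key columns stay compatible for the join.
with pl.StringCache():
    left = pl.DataFrame({"k": ["a", "b"]}).with_columns(pl.col("k").cast(pl.Categorical))
    right = pl.DataFrame({"k": ["b", "c"]}).with_columns(pl.col("k").cast(pl.Categorical))
    joined = left.join(right, on="k", how="inner")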
@@ -187,6 +188,7 @@ def load_flat_rep(
         if do_cache_filtered_task:
             cached_fp.parent.mkdir(exist_ok=True, parents=True)
             df.collect().write_parquet(cached_fp, use_pyarrow=True)
+            df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)

         df = df.select("subject_id", "timestamp", *window_features)
         if subjects_included.get(sp, None) is not None:
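The added line follows a common Polars caching pattern: materialize an expensive lazy query to Parquet once, then immediately re-point the lazy frame at the cached file so later collections read from disk instead of recomputing. A rough standalone sketch of that pattern (the function name, file path, and columns are made up for illustration):

import polars as pl

def cached_lazy(df: pl.LazyFrame, cache_fp: str, columns: list[str]) -> pl.LazyFrame:
    """Write the filtered result once, then serve all further reads from the Parquet cache."""
    df.collect().write_parquet(cache_fp, use_pyarrow=True)  # materialize the expensive query
    return pl.scan_parquet(cache_fp).select(columns)        # cheap lazy scans from here on

lazy = pl.LazyFrame({"subject_id": [1, 2], "timestamp": [10, 20], "feat": [0.1, 0.2]})
rep = cached_lazy(lazy, "cached_rep.parquet", ["subject_id", "timestamp", "feat"])
print(rep.collect())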
@@ -649,7 +651,7 @@ def eval_binary_classification(Y: np.ndarray, probs: np.ndarray) -> dict[str, fl


 def train_sklearn_pipeline(cfg: SklearnConfig):
-    print(f"Saving config to {cfg.save_dir / 'config.yaml'}")
+    logger.info(f"Saving config to {cfg.save_dir / 'config.yaml'}")
     cfg.save_dir.mkdir(exist_ok=True, parents=True)
     OmegaConf.save(cfg, cfg.save_dir / "config.yaml")

@@ -658,7 +660,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
     task_dfs = add_tasks_from(ESD.config.save_dir / "task_dfs")
     task_df = task_dfs[cfg.task_df_name]

-    task_type, normalized_label = PytorchDataset.normalize_task(
+    task_type, normalized_label = ConstructorPytorchDataset.normalize_task(
         pl.col(cfg.finetuning_task_label), task_df.schema[cfg.finetuning_task_label]
     )

@@ -674,7 +676,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):

     # TODO(mmd): Window sizes may violate start_time constraints in task dfs!

-    print(f"Loading representations for {', '.join(cfg.feature_selector.window_sizes)}")
+    logger.info(f"Loading representations for {', '.join(cfg.feature_selector.window_sizes)}")
     subjects_included = {}

     if cfg.train_subset_size not in (None, "FULL"):
@@ -706,24 +708,26 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
     Xs_and_Ys = {}
     for split in ("train", "tuning", "held_out"):
         st = datetime.now()
-        print(f"Loading dataset for {split}")
+        logger.info(f"Loading dataset for {split}")
         df = flat_reps[split].with_columns(normalized_label.alias(cfg.finetuning_task_label)).collect()

         X = df.drop(["subject_id", "timestamp", cfg.finetuning_task_label])
         Y = df[cfg.finetuning_task_label].to_numpy()
-        print(f"Done with {split} dataset with X of shape {X.shape} " f"(elapsed: {datetime.now() - st})")
+        logger.info(
+            f"Done with {split} dataset with X of shape {X.shape} " f"(elapsed: {datetime.now() - st})"
+        )
         Xs_and_Ys[split] = (X, Y)

-    print("Initializing model!")
+    logger.info("Initializing model!")
     model = cfg.get_model(dataset=ESD)

-    print("Fitting model!")
+    logger.info("Fitting model!")
     model.fit(*Xs_and_Ys["train"])
-    print(f"Saving model to {cfg.save_dir}")
+    logger.info(f"Saving model to {cfg.save_dir}")
     with open(cfg.save_dir / "model.pkl", mode="wb") as f:
         pickle.dump(model, f)

-    print("Evaluating model!")
+    logger.info("Evaluating model!")
     all_metrics = {}
     for split in ("tuning", "held_out"):
         X, Y = Xs_and_Ys[split]
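The print-to-logger.info swap throughout this file buys configurable sinks, levels, and timestamps for free. A small, hypothetical loguru setup showing what a downstream script could do with these messages (the log file name and levels are illustrative, not part of this commit):

import sys
from loguru import logger

logger.remove()                            # drop loguru's default stderr sink
logger.add(sys.stderr, level="INFO")       # human-readable console output
logger.add("train_sklearn.log",            # hypothetical log file capturing full detail
           level="DEBUG", rotation="10 MB")

logger.info("Initializing model!")         # now lands in both sinks, with timestamps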
141 changes: 138 additions & 3 deletions EventStream/data/config.py
@@ -4,6 +4,8 @@

 import dataclasses
 import enum
+import hashlib
+import json
 import random
 from collections import OrderedDict, defaultdict
 from collections.abc import Hashable, Sequence
@@ -14,6 +16,7 @@

 import omegaconf
 import pandas as pd
+from loguru import logger

 from ..utils import (
     COUNT_OR_PROPORTION,
@@ -849,6 +852,10 @@ class PytorchDatasetConfig(JSONableMixin):
     Traceback (most recent call last):
         ...
     TypeError: train_subset_size is of unrecognized type <class 'str'>.
+    >>> import sys
+    >>> from loguru import logger
+    >>> logger.remove()
+    >>> _ = logger.add(sys.stdout, format="{message}")
     >>> config = PytorchDatasetConfig(
     ...     save_dir='./dataset',
     ...     max_seq_len=256,
@@ -860,7 +867,7 @@ class PytorchDatasetConfig(JSONableMixin):
     ...     task_df_name=None,
     ...     do_include_start_time_min=False
     ... )
-    WARNING! train_subset_size is set, but train_subset_seed is not. Setting to...
+    train_subset_size is set, but train_subset_seed is not. Setting to...
     >>> assert config.train_subset_seed is not None
     """

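Why the doctest now reconfigures loguru first: by default loguru writes to stderr with a timestamped format, which doctest (which only captures stdout and matches text exactly) would never see or match. Re-adding a bare-message stdout sink makes the expected warning line deterministic. A standalone illustration of the same trick, assuming only loguru itself:

import sys
from loguru import logger

def configure_for_doctest() -> None:
    """Route loguru output to stdout with no timestamps so doctests can match it."""
    logger.remove()                             # drop the default stderr sink
    logger.add(sys.stdout, format="{message}")  # emit bare messages only

configure_for_doctest()
logger.warning("train_subset_size is set, but train_subset_seed is not. Setting to 42")
# prints exactly: train_subset_size is set, but train_subset_seed is not. Setting to 42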
@@ -880,7 +887,19 @@ class PytorchDatasetConfig(JSONableMixin):
     do_include_subject_id: bool = False
     do_include_start_time_min: bool = False

+    # Trades off between speed/disk/mem and support
+    cache_for_epochs: int = 1
+
     def __post_init__(self):
+        if self.cache_for_epochs is None:
+            self.cache_for_epochs = 1
+
+        if self.subsequence_sampling_strategy != "random" and self.cache_for_epochs > 1:
+            raise ValueError(
+                f"It does not make sense to cache for {self.cache_for_epochs} with non-random "
+                "subsequence sampling."
+            )
+
         if self.seq_padding_side not in SeqPaddingSide.values():
             raise ValueError(f"seq_padding_side invalid; must be in {', '.join(SeqPaddingSide.values())}")
         if type(self.min_seq_len) is not int or self.min_seq_len < 0:
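For context, a hypothetical construction that would trip the new cache_for_epochs guard. This is a sketch only: the "to_end" strategy name and the other argument values are assumptions for illustration, and PytorchDatasetConfig is assumed to be imported from EventStream.data.config.

# Caching a deterministic (non-random) subsequence sample for several epochs would
# replay identical batches, so __post_init__ rejects the combination.
try:
    config = PytorchDatasetConfig(
        save_dir="./dataset",
        max_seq_len=256,
        subsequence_sampling_strategy="to_end",  # assumed non-random strategy name
        cache_for_epochs=5,                      # but asking to reuse the cached sample
    )
except ValueError as e:
    print(e)  # "It does not make sense to cache for 5 with non-random subsequence sampling."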
@@ -901,8 +920,11 @@ def __post_init__(self):
                 raise ValueError(f"If float, train_subset_size must be in (0, 1)! Got {frac}")
             case int() | float() if (self.train_subset_seed is None):
                 seed = int(random.randint(1, int(1e6)))
-                print(f"WARNING! train_subset_size is set, but train_subset_seed is not. Setting to {seed}")
+                logger.warning(f"train_subset_size is set, but train_subset_seed is not. Setting to {seed}")
                 self.train_subset_seed = seed
+            case None | "FULL" if self.train_subset_seed is not None:
+                logger.info(f"Removing train subset seed as train subset size is {self.train_subset_size}")
+                self.train_subset_seed = None
             case None | "FULL" | int() | float():
                 pass
             case _:
@@ -920,6 +942,119 @@ def from_dict(cls, as_dict: dict) -> PytorchDatasetConfig:
         as_dict["save_dir"] = Path(as_dict["save_dir"])
         return cls(**as_dict)

+    @property
+    def vocabulary_config_fp(self) -> Path:
+        return self.save_dir / "vocabulary_config.json"
+
+    @property
+    def vocabulary_config(self) -> VocabularyConfig:
+        return VocabularyConfig.from_json_file(self.vocabulary_config_fp)
+
+    @property
+    def measurement_config_fp(self) -> Path:
+        return self.save_dir / "inferred_measurement_configs.json"
+
+    @property
+    def measurement_configs(self) -> dict[str, MeasurementConfig]:
+        with open(self.measurement_config_fp) as f:
+            measurement_configs = {k: MeasurementConfig.from_dict(v) for k, v in json.load(f).items()}
+        return {k: v for k, v in measurement_configs.items() if not v.is_dropped}
+
+    @property
+    def DL_reps_dir(self) -> Path:
+        return self.save_dir / "DL_reps"
+
+    @property
+    def cached_task_dir(self) -> Path | None:
+        if self.task_df_name is None:
+            return None
+        else:
+            return self.save_dir / "DL_reps" / "for_task" / self.task_df_name
+
+    @property
+    def raw_task_df_fp(self) -> Path | None:
+        if self.task_df_name is None:
+            return None
+        else:
+            return self.save_dir / "task_dfs" / f"{self.task_df_name}.parquet"
+
+    @property
+    def task_info_fp(self) -> Path | None:
+        if self.task_df_name is None:
+            return None
+        else:
+            return self.cached_task_dir / "task_info.json"
+
+    @property
+    def _data_parameters_and_hash(self) -> tuple[dict[str, Any], str]:
+        params = sorted(
+            (
+                "save_dir",
+                "max_seq_len",
+                "min_seq_len",
+                "seq_padding_side",
+                "subsequence_sampling_strategy",
+                "train_subset_size",
+                "train_subset_seed",
+                "task_df_name",
+            )
+        )
+
+        params_list = []
+        for p in params:
+            v = str(getattr(self, p))
+            if (p == "train_subset_seed") and (self.train_subset_size in ("FULL", None)):
+                v = None
+            params_list.append((p, v))
+
+        params = tuple(params_list)
+        h = hashlib.blake2b(digest_size=8)
+        h.update(str(params).encode())
+
+        return {k: v for k, v in params}, h.hexdigest()
+
+    @property
+    def tensorized_cached_dir(self) -> Path:
+        if self.task_df_name is None:
+            base_dir = self.DL_reps_dir / "tensorized_cached"
+        else:
+            base_dir = self.cached_task_dir
+
+        return base_dir / self._data_parameters_and_hash[1]
+
+    @property
+    def _cached_data_parameters_fp(self) -> Path:
+        return self.tensorized_cached_dir / "data_parameters.json"
+
+    def _cache_data_parameters(self):
+        self._cached_data_parameters_fp.parent.mkdir(exist_ok=True, parents=True)
+
+        with open(self._cached_data_parameters_fp, mode="w") as f:
+            logger.info(f"Saving data parameters to {self._cached_data_parameters_fp}")
+            json.dump(self._data_parameters_and_hash[0], f)
+
+    def tensorized_cached_files(self, split: str) -> dict[str, Path]:
+        if not (self.tensorized_cached_dir / split).is_dir():
+            return {}
+
+        all_files = {fp.stem: fp for fp in (self.tensorized_cached_dir / split).glob("*.pt")}
+        files_str = ", ".join(all_files.keys())
+
+        for param, need_keys in [
+            ("do_include_start_time_min", ["start_time"]),
+            ("do_include_subsequence_indices", ["start_idx", "end_idx"]),
+            ("do_include_subject_id", ["subject_id"]),
+        ]:
+            param_val = getattr(self, param)
+            for need_key in need_keys:
+                if param_val:
+                    if need_key not in all_files.keys():
+                        raise KeyError(f"Missing {need_key} but {param} is True! Have {files_str}")
+                elif need_key in all_files:
+                    all_files.pop(need_key)
+
+        return all_files


@dataclasses.dataclass
class MeasurementConfig(JSONableMixin):
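The new _data_parameters_and_hash and tensorized_cached_dir machinery keys each cache directory on a short, stable hash of the data-shaping parameters, so caches built under incompatible settings cannot collide. A stripped-down sketch of the same idea, independent of this class (the function name and parameter names are arbitrary):

import hashlib

def cache_key(params: dict[str, object]) -> str:
    """Stable 16-hex-digit key: sort the items, stringify, and blake2b-hash them."""
    items = tuple(sorted((k, str(v)) for k, v in params.items()))
    h = hashlib.blake2b(digest_size=8)   # 8-byte digest -> 16 hex characters
    h.update(str(items).encode())
    return h.hexdigest()

# The same parameters always map to the same directory name; any change re-keys the cache.
print(cache_key({"max_seq_len": 256, "seq_padding_side": "right"}))
print(cache_key({"seq_padding_side": "right", "max_seq_len": 256}))  # identical output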
@@ -1633,7 +1768,7 @@ class DatasetConfig(JSONableMixin):
         agg_by_time_scale: Aggregate events into temporal buckets at this frequency. Uses the string language
             described here:
-            https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.groupby_dynamic.html
+            https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html

     Raises:
         ValueError: If configuration parameters are invalid (e.g., proportion parameters being > 1, etc.).
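The docstring fix tracks Polars' rename of groupby_dynamic to group_by_dynamic. For reference, a small example of the time bucketing that an agg_by_time_scale string controls, assuming a recent Polars version (the column names and the "1h" window are illustrative):

import polars as pl
from datetime import datetime

df = pl.DataFrame({
    "timestamp": [datetime(2020, 1, 1, 0, 5), datetime(2020, 1, 1, 0, 50), datetime(2020, 1, 1, 1, 10)],
    "value": [1.0, 2.0, 3.0],
})

# Aggregate events into hourly buckets, analogous to agg_by_time_scale="1h".
hourly = (
    df.sort("timestamp")
      .group_by_dynamic("timestamp", every="1h")
      .agg(pl.col("value").mean().alias("mean_value"))
)
print(hourly)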
