Skip to content

Commit

Permalink
Merge branch 'dev' into enable_flash_attention
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Oct 25, 2023
2 parents a339171 + e620eed commit f479082
Show file tree
Hide file tree
Showing 11 changed files with 853 additions and 1,629 deletions.
13 changes: 8 additions & 5 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "mambaforge-4.10"
python: "3.10"

python:
install:
- method: pip
path: .
extra_requirements:
- docs

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py

# Optionally build your docs in additional formats such as PDF
# formats:
# - pdf

conda:
environment: env_cpu.yml
4 changes: 2 additions & 2 deletions EventStream/data/pytorch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ def __init__(self, config: PytorchDatasetConfig, split: str):

self.has_task = True

if len(list(task_dir.glob("{split}*.parquet"))) > 0:
if len(list(task_dir.glob(f"{split}*.parquet"))) > 0:
print(
f"Re-loading task data for {self.config.task_df_name} from {task_dir}:\n"
f"{', '.join([str(fp) for fp in task_dir.glob('{split}*.parquet')])}"
f"{', '.join([str(fp) for fp in task_dir.glob(f'{split}*.parquet')])}"
)
self.cached_data = pl.scan_parquet(task_dir / f"{split}*.parquet")
with open(task_info_fp) as f:
Expand Down
2 changes: 1 addition & 1 deletion EventStream/evaluation/general_generative_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class GenerateConfig:

do_overwrite: bool = False

optimization_config: OptimizationConfig = OptimizationConfig()
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())

task_df_name: str | None = None

Expand Down
2 changes: 1 addition & 1 deletion EventStream/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def to_dict(self) -> dict[str, Any]:
@classmethod
def from_dict(cls, *args, **kwargs) -> "StructuredTransformerConfig":
raw_from_dict = super().from_dict(*args, **kwargs)
if raw_from_dict.measurmeent_configs:
if raw_from_dict.measurement_configs:

Check warning on line 927 in EventStream/transformer/config.py

View check run for this annotation

Codecov / codecov/patch

EventStream/transformer/config.py#L927

Added line #L927 was not covered by tests
new_meas_configs = {}
for k, v in raw_from_dict.measurement_configs.items():
new_meas_configs[k] = MeasurementConfig.from_dict(v)
Expand Down
2 changes: 1 addition & 1 deletion EventStream/transformer/lightning_modules/fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class FinetuneConfig:
},
}
)
optimization_config: OptimizationConfig = OptimizationConfig()
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())
data_config: dict[str, Any] | None = dataclasses.field(
default_factory=lambda: {
**{k: None for k in PytorchDatasetConfig().to_dict().keys()},
Expand Down
14 changes: 9 additions & 5 deletions EventStream/transformer/lightning_modules/generative_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,16 @@ class PretrainConfig:
},
}
)
optimization_config: OptimizationConfig = OptimizationConfig()
data_config: PytorchDatasetConfig = PytorchDatasetConfig()
pretraining_metrics_config: MetricsConfig = MetricsConfig(
include_metrics={Split.TRAIN: {MetricCategories.LOSS_PARTS: True}},
optimization_config: OptimizationConfig = dataclasses.field(default_factory=lambda: OptimizationConfig())
data_config: PytorchDatasetConfig = dataclasses.field(default_factory=lambda: PytorchDatasetConfig())
pretraining_metrics_config: MetricsConfig = dataclasses.field(
default_factory=lambda: MetricsConfig(
include_metrics={Split.TRAIN: {MetricCategories.LOSS_PARTS: True}},
)
)
final_validation_metrics_config: MetricsConfig = dataclasses.field(
default_factory=lambda: MetricsConfig(do_skip_all_metrics=False)
)
final_validation_metrics_config: MetricsConfig = MetricsConfig(do_skip_all_metrics=False)

trainer_config: dict[str, Any] = dataclasses.field(
default_factory=lambda: {
Expand Down
9 changes: 1 addition & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,7 @@ GitHub issue.
## Installation

Installation of the required dependencies can be done via pip with `pip install -e .` in the root directory of
the repository.

### Other Installation Instructions

Installation (this mode is deprecated) can also be done via conda with the `env.yml` file: `conda env create -n ${ENV_NAME} -f env.yml`

It can also be attempted via poetry, though that approach has issues at present in differentiating between cpu
and gpu machines.
the repository. To be able to run tests, use `pip install -e .[tests]`. To be able to build docs, use `pip install -e .[docs]`.

## Overview

Expand Down
Loading

0 comments on commit f479082

Please sign in to comment.